13 import clientsubnetoption
19 from eqdnsmessage
import AssertEqualDNSMessageMixin
21 # Python2/3 compatibility hacks
23 from queue
import Queue
25 from Queue
import Queue
33 class DNSDistTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
35 Set up a dnsdist instance and responder threads.
36 Queries sent to dnsdist are relayed to the responder threads,
37 who reply with the response provided by the tests themselves
38 on a queue. Responder threads also queue the queries received
39 from dnsdist on a separate queue, allowing the tests to check
40 that the queries sent from dnsdist were as expected.
43 _dnsDistListeningAddr
= "127.0.0.1"
44 _testServerPort
= 5350
45 _toResponderQueue
= Queue()
46 _fromResponderQueue
= Queue()
48 _dnsdistStartupDelay
= 2.0
50 _responsesCounter
= {}
51 _config_template
= """
53 _config_params
= ['_testServerPort']
54 _acl
= ['127.0.0.1/32']
57 _healthCheckName
= 'a.root-servers.net.'
58 _healthCheckCounter
= 0
59 _answerUnexpected
= True
60 _checkConfigExpectedOutput
= None
63 def startResponders(cls
):
64 print("Launching responders..")
66 cls
._UDPResponder
= threading
.Thread(name
='UDP Responder', target
=cls
.UDPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
67 cls
._UDPResponder
.setDaemon(True)
68 cls
._UDPResponder
.start()
69 cls
._TCPResponder
= threading
.Thread(name
='TCP Responder', target
=cls
.TCPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
70 cls
._TCPResponder
.setDaemon(True)
71 cls
._TCPResponder
.start()
74 def startDNSDist(cls
):
75 print("Launching dnsdist..")
76 confFile
= os
.path
.join('configs', 'dnsdist_%s.conf' % (cls
.__name
__))
77 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
79 with
open(confFile
, 'w') as conf
:
80 conf
.write("-- Autogenerated by dnsdisttests.py\n")
81 conf
.write(cls
._config
_template
% params
)
83 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
,
84 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
86 dnsdistcmd
.extend(['--acl', acl
])
87 print(' '.join(dnsdistcmd
))
89 # validate config with --check-config, which sets client=true, possibly exposing bugs.
90 testcmd
= dnsdistcmd
+ ['--check-config']
92 output
= subprocess
.check_output(testcmd
, stderr
=subprocess
.STDOUT
, close_fds
=True)
93 except subprocess
.CalledProcessError
as exc
:
94 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc
.returncode
, exc
.output
))
95 if cls
._checkConfigExpectedOutput
is not None:
96 expectedOutput
= cls
._checkConfigExpectedOutput
98 expectedOutput
= ('Configuration \'%s\' OK!\n' % (confFile
)).encode()
99 if output
!= expectedOutput
:
100 raise AssertionError('dnsdist --check-config failed: %s' % output
)
102 logFile
= os
.path
.join('configs', 'dnsdist_%s.log' % (cls
.__name
__))
103 with
open(logFile
, 'w') as fdLog
:
104 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdLog
, stderr
=fdLog
)
106 if 'DNSDIST_FAST_TESTS' in os
.environ
:
109 delay
= cls
._dnsdistStartupDelay
113 if cls
._dnsdist
.poll() is not None:
115 sys
.exit(cls
._dnsdist
.returncode
)
118 def setUpSockets(cls
):
119 print("Setting up UDP socket..")
120 cls
._sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
121 cls
._sock
.settimeout(2.0)
122 cls
._sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
127 cls
.startResponders()
131 print("Launching tests..")
134 def tearDownClass(cls
):
135 if 'DNSDIST_FAST_TESTS' in os
.environ
:
140 cls
._dnsdist
.terminate()
141 if cls
._dnsdist
.poll() is None:
143 if cls
._dnsdist
.poll() is None:
148 def _ResponderIncrementCounter(cls
):
149 if threading
.currentThread().name
in cls
._responsesCounter
:
150 cls
._responsesCounter
[threading
.currentThread().name
] += 1
152 cls
._responsesCounter
[threading
.currentThread().name
] = 1
155 def _getResponse(cls
, request
, fromQueue
, toQueue
, synthesize
=None):
157 if len(request
.question
) != 1:
158 print("Skipping query with question count %d" % (len(request
.question
)))
160 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
162 cls
._healthCheckCounter
+= 1
163 response
= dns
.message
.make_response(request
)
165 cls
._ResponderIncrementCounter
()
166 if not fromQueue
.empty():
167 toQueue
.put(request
, True, cls
._queueTimeout
)
168 if synthesize
is None:
169 response
= fromQueue
.get(True, cls
._queueTimeout
)
171 response
= copy
.copy(response
)
172 response
.id = request
.id
175 if synthesize
is not None:
176 response
= dns
.message
.make_response(request
)
177 response
.set_rcode(synthesize
)
178 elif cls
._answerUnexpected
:
179 response
= dns
.message
.make_response(request
)
180 response
.set_rcode(dns
.rcode
.SERVFAIL
)
185 def UDPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, callback
=None):
186 # trailingDataResponse=True means "ignore trailing data".
187 # Other values are either False (meaning "raise an exception")
188 # or are interpreted as a response RCODE for queries with trailing data.
189 # callback is invoked for every -even healthcheck ones- query and should return a raw response
190 ignoreTrailing
= trailingDataResponse
is True
192 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
193 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
194 sock
.bind(("127.0.0.1", port
))
196 data
, addr
= sock
.recvfrom(4096)
199 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
200 except dns
.message
.TrailingJunk
as e
:
201 if trailingDataResponse
is False or forceRcode
is True:
203 print("UDP query with trailing data, synthesizing response")
204 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
205 forceRcode
= trailingDataResponse
209 wire
= callback(request
)
211 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
213 wire
= response
.to_wire()
219 sock
.sendto(wire
, addr
)
220 sock
.settimeout(None)
224 def TCPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None):
225 # trailingDataResponse=True means "ignore trailing data".
226 # Other values are either False (meaning "raise an exception")
227 # or are interpreted as a response RCODE for queries with trailing data.
228 # callback is invoked for every -even healthcheck ones- query and should return a raw response
229 ignoreTrailing
= trailingDataResponse
is True
231 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
232 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
233 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
235 sock
.bind(("127.0.0.1", port
))
236 except socket
.error
as e
:
237 print("Error binding in the TCP responder: %s" % str(e
))
242 (conn
, _
) = sock
.accept()
249 (datalen
,) = struct
.unpack("!H", data
)
250 data
= conn
.recv(datalen
)
253 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
254 except dns
.message
.TrailingJunk
as e
:
255 if trailingDataResponse
is False or forceRcode
is True:
257 print("TCP query with trailing data, synthesizing response")
258 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
259 forceRcode
= trailingDataResponse
262 wire
= callback(request
)
264 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
266 wire
= response
.to_wire(max_size
=65535)
272 conn
.send(struct
.pack("!H", len(wire
)))
275 while multipleResponses
:
276 if fromQueue
.empty():
279 response
= fromQueue
.get(True, cls
._queueTimeout
)
283 response
= copy
.copy(response
)
284 response
.id = request
.id
285 wire
= response
.to_wire(max_size
=65535)
287 conn
.send(struct
.pack("!H", len(wire
)))
289 except socket
.error
as e
:
290 # some of the tests are going to close
291 # the connection on us, just deal with it
299 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
301 cls
._toResponderQueue
.put(response
, True, timeout
)
304 cls
._sock
.settimeout(timeout
)
308 query
= query
.to_wire()
309 cls
._sock
.send(query
)
310 data
= cls
._sock
.recv(4096)
311 except socket
.timeout
:
315 cls
._sock
.settimeout(None)
319 if useQueue
and not cls
._fromResponderQueue
.empty():
320 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
322 message
= dns
.message
.from_wire(data
)
323 return (receivedQuery
, message
)
326 def openTCPConnection(cls
, timeout
=None):
327 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
328 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
330 sock
.settimeout(timeout
)
332 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
336 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None):
337 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
338 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
340 sock
.settimeout(timeout
)
343 if hasattr(ssl
, 'create_default_context'):
344 sslctx
= ssl
.create_default_context(cafile
=caCert
)
345 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
347 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
349 sslsock
.connect(("127.0.0.1", port
))
353 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
355 wire
= query
.to_wire()
360 cls
._toResponderQueue
.put(response
, True, timeout
)
362 sock
.send(struct
.pack("!H", len(wire
)))
366 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
370 (datalen
,) = struct
.unpack("!H", data
)
371 data
= sock
.recv(datalen
)
373 message
= dns
.message
.from_wire(data
)
375 if useQueue
and not cls
._fromResponderQueue
.empty():
376 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
377 return (receivedQuery
, message
)
382 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
385 cls
._toResponderQueue
.put(response
, True, timeout
)
387 sock
= cls
.openTCPConnection(timeout
)
390 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
391 message
= cls
.recvTCPResponseOverConnection(sock
)
392 except socket
.timeout
as e
:
393 print("Timeout: %s" % (str(e
)))
394 except socket
.error
as e
:
395 print("Network error: %s" % (str(e
)))
400 if useQueue
and not cls
._fromResponderQueue
.empty():
401 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
403 return (receivedQuery
, message
)
406 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
408 for response
in responses
:
409 cls
._toResponderQueue
.put(response
, True, timeout
)
410 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
411 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
413 sock
.settimeout(timeout
)
415 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
420 wire
= query
.to_wire()
424 sock
.send(struct
.pack("!H", len(wire
)))
430 (datalen
,) = struct
.unpack("!H", data
)
431 data
= sock
.recv(datalen
)
432 messages
.append(dns
.message
.from_wire(data
))
434 except socket
.timeout
as e
:
435 print("Timeout: %s" % (str(e
)))
436 except socket
.error
as e
:
437 print("Network error: %s" % (str(e
)))
442 if useQueue
and not cls
._fromResponderQueue
.empty():
443 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
444 return (receivedQuery
, messages
)
447 # This function is called before every tests
449 # Clear the responses counters
450 for key
in self
._responsesCounter
:
451 self
._responsesCounter
[key
] = 0
453 self
._healthCheckCounter
= 0
455 # Make sure the queues are empty, in case
456 # a previous test failed
457 while not self
._toResponderQueue
.empty():
458 self
._toResponderQueue
.get(False)
460 while not self
._fromResponderQueue
.empty():
461 self
._fromResponderQueue
.get(False)
463 super(DNSDistTest
, self
).setUp()
466 def clearToResponderQueue(cls
):
467 while not cls
._toResponderQueue
.empty():
468 cls
._toResponderQueue
.get(False)
471 def clearFromResponderQueue(cls
):
472 while not cls
._fromResponderQueue
.empty():
473 cls
._fromResponderQueue
.get(False)
476 def clearResponderQueues(cls
):
477 cls
.clearToResponderQueue()
478 cls
.clearFromResponderQueue()
481 def generateConsoleKey():
482 return libnacl
.utils
.salsa_key()
485 def _encryptConsole(cls
, command
, nonce
):
486 command
= command
.encode('UTF-8')
487 if cls
._consoleKey
is None:
489 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
492 def _decryptConsole(cls
, command
, nonce
):
493 if cls
._consoleKey
is None:
496 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
497 return result
.decode('UTF-8')
500 def sendConsoleCommand(cls
, command
, timeout
=1.0):
501 ourNonce
= libnacl
.utils
.rand_nonce()
503 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
504 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
506 sock
.settimeout(timeout
)
508 sock
.connect(("127.0.0.1", cls
._consolePort
))
510 theirNonce
= sock
.recv(len(ourNonce
))
511 if len(theirNonce
) != len(ourNonce
):
512 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
513 if len(theirNonce
) == 0:
514 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
517 halfNonceSize
= int(len(ourNonce
) / 2)
518 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
519 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
520 msg
= cls
._encryptConsole
(command
, writingNonce
)
521 sock
.send(struct
.pack("!I", len(msg
)))
525 raise socket
.error("Got EOF while reading the response size")
527 (responseLen
,) = struct
.unpack("!I", data
)
528 data
= sock
.recv(responseLen
)
529 response
= cls
._decryptConsole
(data
, readingNonce
)
532 def compareOptions(self
, a
, b
):
533 self
.assertEquals(len(a
), len(b
))
534 for idx
in range(len(a
)):
535 self
.assertEquals(a
[idx
], b
[idx
])
537 def checkMessageNoEDNS(self
, expected
, received
):
538 self
.assertEquals(expected
, received
)
539 self
.assertEquals(received
.edns
, -1)
540 self
.assertEquals(len(received
.options
), 0)
542 def checkMessageEDNSWithoutOptions(self
, expected
, received
):
543 self
.assertEquals(expected
, received
)
544 self
.assertEquals(received
.edns
, 0)
545 self
.assertEquals(expected
.payload
, received
.payload
)
547 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
548 self
.assertEquals(expected
, received
)
549 self
.assertEquals(received
.edns
, 0)
550 self
.assertEquals(expected
.payload
, received
.payload
)
551 self
.assertEquals(len(received
.options
), withCookies
)
553 for option
in received
.options
:
554 self
.assertEquals(option
.otype
, 10)
556 for option
in received
.options
:
557 self
.assertNotEquals(option
.otype
, 10)
559 def checkMessageEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
560 self
.assertEquals(expected
, received
)
561 self
.assertEquals(received
.edns
, 0)
562 self
.assertEquals(expected
.payload
, received
.payload
)
563 self
.assertEquals(len(received
.options
), 1 + additionalOptions
)
565 for option
in received
.options
:
566 if option
.otype
== clientsubnetoption
.ASSIGNED_OPTION_CODE
:
569 self
.assertNotEquals(additionalOptions
, 0)
571 self
.compareOptions(expected
.options
, received
.options
)
572 self
.assertTrue(hasECS
)
574 def checkQueryEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
575 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
577 def checkResponseEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
578 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
580 def checkQueryEDNSWithoutECS(self
, expected
, received
):
581 self
.checkMessageEDNSWithoutECS(expected
, received
)
583 def checkResponseEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
584 self
.checkMessageEDNSWithoutECS(expected
, received
, withCookies
)
586 def checkQueryNoEDNS(self
, expected
, received
):
587 self
.checkMessageNoEDNS(expected
, received
)
589 def checkResponseNoEDNS(self
, expected
, received
):
590 self
.checkMessageNoEDNS(expected
, received
)