16 import clientsubnetoption
29 from io
import BytesIO
31 from eqdnsmessage
import AssertEqualDNSMessageMixin
32 from proxyprotocol
import ProxyProtocol
34 # Python2/3 compatibility hacks
36 from queue
import Queue
38 from Queue
import Queue
46 class DNSDistTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
48 Set up a dnsdist instance and responder threads.
49 Queries sent to dnsdist are relayed to the responder threads,
50 who reply with the response provided by the tests themselves
51 on a queue. Responder threads also queue the queries received
52 from dnsdist on a separate queue, allowing the tests to check
53 that the queries sent from dnsdist were as expected.
56 _dnsDistListeningAddr
= "127.0.0.1"
57 _testServerPort
= 5350
58 _toResponderQueue
= Queue()
59 _fromResponderQueue
= Queue()
62 _responsesCounter
= {}
63 _config_template
= """
65 _config_params
= ['_testServerPort']
66 _acl
= ['127.0.0.1/32']
69 _healthCheckName
= 'a.root-servers.net.'
70 _healthCheckCounter
= 0
71 _answerUnexpected
= True
72 _checkConfigExpectedOutput
= None
74 _skipListeningOnCL
= False
75 _alternateListeningAddr
= None
76 _alternateListeningPort
= None
77 _backgroundThreads
= {}
80 _extraStartupSleep
= 0
83 def waitForTCPSocket(cls
, ipaddress
, port
):
84 for try_number
in range(0, 20):
86 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
88 sock
.connect((ipaddress
, port
))
91 except Exception as err
:
92 if err
.errno
!= errno
.ECONNREFUSED
:
93 print(f
'Error occurred: {try_number} {err}', file=sys
.stderr
)
95 # We assume the dnsdist instance does not listen. That's fine.
98 def startResponders(cls
):
99 print("Launching responders..")
101 cls
._UDPResponder
= threading
.Thread(name
='UDP Responder', target
=cls
.UDPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
102 cls
._UDPResponder
.setDaemon(True)
103 cls
._UDPResponder
.start()
104 cls
._TCPResponder
= threading
.Thread(name
='TCP Responder', target
=cls
.TCPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
105 cls
._TCPResponder
.setDaemon(True)
106 cls
._TCPResponder
.start()
107 cls
.waitForTCPSocket("127.0.0.1", cls
._testServerPort
);
110 def startDNSDist(cls
):
111 print("Launching dnsdist..")
112 confFile
= os
.path
.join('configs', 'dnsdist_%s.conf' % (cls
.__name
__))
113 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
115 with
open(confFile
, 'w') as conf
:
116 conf
.write("-- Autogenerated by dnsdisttests.py\n")
117 conf
.write(cls
._config
_template
% params
)
118 conf
.write("setSecurityPollSuffix('')")
120 if cls
._skipListeningOnCL
:
121 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
]
123 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
,
124 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
127 dnsdistcmd
.append('-v')
130 dnsdistcmd
.extend(['--acl', acl
])
131 print(' '.join(dnsdistcmd
))
133 # validate config with --check-config, which sets client=true, possibly exposing bugs.
134 testcmd
= dnsdistcmd
+ ['--check-config']
136 output
= subprocess
.check_output(testcmd
, stderr
=subprocess
.STDOUT
, close_fds
=True)
137 except subprocess
.CalledProcessError
as exc
:
138 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc
.returncode
, exc
.output
))
139 if cls
._checkConfigExpectedOutput
is not None:
140 expectedOutput
= cls
._checkConfigExpectedOutput
142 expectedOutput
= ('Configuration \'%s\' OK!\n' % (confFile
)).encode()
143 if not cls
._verboseMode
and output
!= expectedOutput
:
144 raise AssertionError('dnsdist --check-config failed: %s' % output
)
146 logFile
= os
.path
.join('configs', 'dnsdist_%s.log' % (cls
.__name
__))
147 with
open(logFile
, 'w') as fdLog
:
148 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdLog
, stderr
=fdLog
)
150 if cls
._alternateListeningAddr
and cls
._alternateListeningPort
:
151 cls
.waitForTCPSocket(cls
._alternateListeningAddr
, cls
._alternateListeningPort
)
153 cls
.waitForTCPSocket(cls
._dnsDistListeningAddr
, cls
._dnsDistPort
)
155 if cls
._dnsdist
.poll() is not None:
156 print(f
"\n*** startDNSDist log for {logFile} ***")
157 with
open(logFile
, 'r') as fdLog
:
159 print(f
"*** End startDNSDist log for {logFile} ***")
160 raise AssertionError('%s failed (%d)' % (dnsdistcmd
, cls
._dnsdist
.returncode
))
161 time
.sleep(cls
._extraStartupSleep
)
164 def setUpSockets(cls
):
165 print("Setting up UDP socket..")
166 cls
._sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
167 cls
._sock
.settimeout(2.0)
168 cls
._sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
171 def killProcess(cls
, p
):
172 # Don't try to kill it if it's already dead
173 if p
.poll() is not None:
177 for count
in range(20):
183 print("kill...", p
, file=sys
.stderr
)
187 # There is a race-condition with the poll() and
188 # kill() statements, when the process is dead on the
189 # kill(), this is fine
190 if e
.errno
!= errno
.ESRCH
:
196 cls
.startResponders()
200 print("Launching tests..")
203 def tearDownClass(cls
):
205 # tell the background threads to stop, if any
206 for backgroundThread
in cls
._backgroundThreads
:
207 cls
._backgroundThreads
[backgroundThread
] = False
208 cls
.killProcess(cls
._dnsdist
)
211 def _ResponderIncrementCounter(cls
):
212 if threading
.currentThread().name
in cls
._responsesCounter
:
213 cls
._responsesCounter
[threading
.currentThread().name
] += 1
215 cls
._responsesCounter
[threading
.currentThread().name
] = 1
218 def _getResponse(cls
, request
, fromQueue
, toQueue
, synthesize
=None):
220 if len(request
.question
) != 1:
221 print("Skipping query with question count %d" % (len(request
.question
)))
223 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
225 cls
._healthCheckCounter
+= 1
226 response
= dns
.message
.make_response(request
)
228 cls
._ResponderIncrementCounter
()
229 if not fromQueue
.empty():
230 toQueue
.put(request
, True, cls
._queueTimeout
)
231 response
= fromQueue
.get(True, cls
._queueTimeout
)
233 response
= copy
.copy(response
)
234 response
.id = request
.id
236 if synthesize
is not None:
237 response
= dns
.message
.make_response(request
)
238 response
.set_rcode(synthesize
)
241 if cls
._answerUnexpected
:
242 response
= dns
.message
.make_response(request
)
243 response
.set_rcode(dns
.rcode
.SERVFAIL
)
248 def UDPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, callback
=None):
249 cls
._backgroundThreads
[threading
.get_native_id()] = True
250 # trailingDataResponse=True means "ignore trailing data".
251 # Other values are either False (meaning "raise an exception")
252 # or are interpreted as a response RCODE for queries with trailing data.
253 # callback is invoked for every -even healthcheck ones- query and should return a raw response
254 ignoreTrailing
= trailingDataResponse
is True
256 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
257 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
258 sock
.bind(("127.0.0.1", port
))
262 data
, addr
= sock
.recvfrom(4096)
263 except socket
.timeout
:
264 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
265 del cls
._backgroundThreads
[threading
.get_native_id()]
272 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
273 except dns
.message
.TrailingJunk
as e
:
274 print('trailing data exception in UDPResponder')
275 if trailingDataResponse
is False or forceRcode
is True:
277 print("UDP query with trailing data, synthesizing response")
278 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
279 forceRcode
= trailingDataResponse
283 wire
= callback(request
)
286 forceRcode
= dns
.rcode
.BADVERS
287 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
289 wire
= response
.to_wire()
294 sock
.sendto(wire
, addr
)
299 def handleTCPConnection(cls
, conn
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None):
300 ignoreTrailing
= trailingDataResponse
is True
306 (datalen
,) = struct
.unpack("!H", data
)
307 data
= conn
.recv(datalen
)
310 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
311 except dns
.message
.TrailingJunk
as e
:
312 if trailingDataResponse
is False or forceRcode
is True:
314 print("TCP query with trailing data, synthesizing response")
315 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
316 forceRcode
= trailingDataResponse
319 wire
= callback(request
)
322 forceRcode
= dns
.rcode
.BADVERS
323 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
325 wire
= response
.to_wire(max_size
=65535)
331 conn
.send(struct
.pack("!H", len(wire
)))
334 while multipleResponses
:
335 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
336 # otherwise we might read responses intended for the next connection
337 if fromQueue
.empty():
340 response
= fromQueue
.get(False)
344 response
= copy
.copy(response
)
345 response
.id = request
.id
346 wire
= response
.to_wire(max_size
=65535)
348 conn
.send(struct
.pack("!H", len(wire
)))
350 except socket
.error
as e
:
351 # some of the tests are going to close
352 # the connection on us, just deal with it
358 def TCPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, multipleConnections
=False, listeningAddr
='127.0.0.1'):
359 cls
._backgroundThreads
[threading
.get_native_id()] = True
360 # trailingDataResponse=True means "ignore trailing data".
361 # Other values are either False (meaning "raise an exception")
362 # or are interpreted as a response RCODE for queries with trailing data.
363 # callback is invoked for every -even healthcheck ones- query and should return a raw response
365 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
366 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
367 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
369 sock
.bind((listeningAddr
, port
))
370 except socket
.error
as e
:
371 print("Error binding in the TCP responder: %s" % str(e
))
377 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
381 (conn
, _
) = sock
.accept()
384 except ConnectionResetError
:
386 except socket
.timeout
:
387 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
388 del cls
._backgroundThreads
[threading
.get_native_id()]
394 if multipleConnections
:
395 thread
= threading
.Thread(name
='TCP Connection Handler',
396 target
=cls
.handleTCPConnection
,
397 args
=[conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
])
398 thread
.setDaemon(True)
401 cls
.handleTCPConnection(conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
)
406 def handleDoHConnection(cls
, config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
):
407 ignoreTrailing
= trailingDataResponse
is True
409 h2conn
= h2
.connection
.H2Connection(config
=config
)
410 h2conn
.initiate_connection()
411 conn
.sendall(h2conn
.data_to_send())
412 except ssl
.SSLEOFError
as e
:
413 print("Unexpected EOF: %s" % (e
))
419 # try to read the entire Proxy Protocol header
420 proxy
= ProxyProtocol()
421 header
= conn
.recv(proxy
.HEADER_SIZE
)
423 print('unable to get header')
427 if not proxy
.parseHeader(header
):
428 print('unable to parse header')
433 proxyContent
= conn
.recv(proxy
.contentLen
)
435 print('unable to get content')
439 payload
= header
+ proxyContent
440 toQueue
.put(payload
, True, cls
._queueTimeout
)
442 # be careful, HTTP/2 headers and data might be in different recv() results
443 requestHeaders
= None
445 data
= conn
.recv(65535)
449 events
= h2conn
.receive_data(data
)
451 if isinstance(event
, h2
.events
.RequestReceived
):
452 requestHeaders
= event
.headers
453 if isinstance(event
, h2
.events
.DataReceived
):
454 h2conn
.acknowledge_received_data(event
.flow_controlled_length
, event
.stream_id
)
455 if not event
.stream_id
in dnsData
:
456 dnsData
[event
.stream_id
] = b
''
457 dnsData
[event
.stream_id
] = dnsData
[event
.stream_id
] + (event
.data
)
458 if event
.stream_ended
:
462 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=ignoreTrailing
)
463 except dns
.message
.TrailingJunk
as e
:
464 if trailingDataResponse
is False or forceRcode
is True:
466 print("DOH query with trailing data, synthesizing response")
467 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=True)
468 forceRcode
= trailingDataResponse
471 status
, wire
= callback(request
, requestHeaders
, fromQueue
, toQueue
)
473 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
475 wire
= response
.to_wire(max_size
=65535)
483 (':status', str(status
)),
484 ('content-length', str(len(wire
))),
485 ('content-type', 'application/dns-message'),
487 h2conn
.send_headers(stream_id
=event
.stream_id
, headers
=headers
)
488 h2conn
.send_data(stream_id
=event
.stream_id
, data
=wire
, end_stream
=True)
490 data_to_send
= h2conn
.data_to_send()
492 conn
.sendall(data_to_send
)
501 def DOHResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, useProxyProtocol
=False):
502 cls
._backgroundThreads
[threading
.get_native_id()] = True
503 # trailingDataResponse=True means "ignore trailing data".
504 # Other values are either False (meaning "raise an exception")
505 # or are interpreted as a response RCODE for queries with trailing data.
506 # callback is invoked for every -even healthcheck ones- query and should return a raw response
508 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
509 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
510 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
512 sock
.bind(("127.0.0.1", port
))
513 except socket
.error
as e
:
514 print("Error binding in the TCP responder: %s" % str(e
))
520 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
522 config
= h2
.config
.H2Configuration(client_side
=False)
526 (conn
, _
) = sock
.accept()
529 except ConnectionResetError
:
531 except socket
.timeout
:
532 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
533 del cls
._backgroundThreads
[threading
.get_native_id()]
539 thread
= threading
.Thread(name
='DoH Connection Handler',
540 target
=cls
.handleDoHConnection
,
541 args
=[config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
])
542 thread
.setDaemon(True)
548 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
549 if useQueue
and response
is not None:
550 cls
._toResponderQueue
.put(response
, True, timeout
)
553 cls
._sock
.settimeout(timeout
)
557 query
= query
.to_wire()
558 cls
._sock
.send(query
)
559 data
= cls
._sock
.recv(4096)
560 except socket
.timeout
:
564 cls
._sock
.settimeout(None)
568 if useQueue
and not cls
._fromResponderQueue
.empty():
569 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
571 message
= dns
.message
.from_wire(data
)
572 return (receivedQuery
, message
)
575 def openTCPConnection(cls
, timeout
=None, port
=None):
576 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
577 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
579 sock
.settimeout(timeout
)
582 port
= cls
._dnsDistPort
584 sock
.connect(("127.0.0.1", port
))
588 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None):
589 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
590 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
592 sock
.settimeout(timeout
)
595 if hasattr(ssl
, 'create_default_context'):
596 sslctx
= ssl
.create_default_context(cafile
=caCert
)
597 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
599 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
601 sslsock
.connect(("127.0.0.1", port
))
605 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
607 wire
= query
.to_wire()
612 cls
._toResponderQueue
.put(response
, True, timeout
)
614 sock
.send(struct
.pack("!H", len(wire
)))
618 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
619 print("reading data")
623 (datalen
,) = struct
.unpack("!H", data
)
625 data
= sock
.recv(datalen
)
628 message
= dns
.message
.from_wire(data
)
631 if useQueue
and not cls
._fromResponderQueue
.empty():
632 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
633 print("Got from queue")
635 return (receivedQuery
, message
)
641 def sendDOTQuery(cls
, port
, serverName
, query
, response
, caFile
, useQueue
=True):
642 conn
= cls
.openTLSConnection(port
, serverName
, caFile
)
643 cls
.sendTCPQueryOverConnection(conn
, query
, response
=response
)
645 return cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
646 return None, cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
649 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
652 cls
._toResponderQueue
.put(response
, True, timeout
)
654 sock
= cls
.openTCPConnection(timeout
)
657 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
658 message
= cls
.recvTCPResponseOverConnection(sock
)
659 except socket
.timeout
as e
:
660 print("Timeout while sending or receiving TCP data: %s" % (str(e
)))
661 except socket
.error
as e
:
662 print("Network error: %s" % (str(e
)))
668 if useQueue
and not cls
._fromResponderQueue
.empty():
669 print("Got from queue")
671 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
673 print("queue is empty")
675 return (receivedQuery
, message
)
678 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
680 for response
in responses
:
681 cls
._toResponderQueue
.put(response
, True, timeout
)
682 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
683 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
685 sock
.settimeout(timeout
)
687 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
692 wire
= query
.to_wire()
696 sock
.send(struct
.pack("!H", len(wire
)))
702 (datalen
,) = struct
.unpack("!H", data
)
703 data
= sock
.recv(datalen
)
704 messages
.append(dns
.message
.from_wire(data
))
706 except socket
.timeout
as e
:
707 print("Timeout while receiving multiple TCP responses: %s" % (str(e
)))
708 except socket
.error
as e
:
709 print("Network error: %s" % (str(e
)))
714 if useQueue
and not cls
._fromResponderQueue
.empty():
715 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
716 return (receivedQuery
, messages
)
719 # This function is called before every test
721 # Clear the responses counters
722 self
._responsesCounter
.clear()
724 self
._healthCheckCounter
= 0
726 # Make sure the queues are empty, in case
727 # a previous test failed
728 self
.clearResponderQueues()
730 super(DNSDistTest
, self
).setUp()
733 def clearToResponderQueue(cls
):
734 while not cls
._toResponderQueue
.empty():
735 cls
._toResponderQueue
.get(False)
738 def clearFromResponderQueue(cls
):
739 while not cls
._fromResponderQueue
.empty():
740 cls
._fromResponderQueue
.get(False)
743 def clearResponderQueues(cls
):
744 cls
.clearToResponderQueue()
745 cls
.clearFromResponderQueue()
748 def generateConsoleKey():
749 return libnacl
.utils
.salsa_key()
752 def _encryptConsole(cls
, command
, nonce
):
753 command
= command
.encode('UTF-8')
754 if cls
._consoleKey
is None:
756 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
759 def _decryptConsole(cls
, command
, nonce
):
760 if cls
._consoleKey
is None:
763 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
764 return result
.decode('UTF-8')
767 def sendConsoleCommand(cls
, command
, timeout
=5.0):
768 ourNonce
= libnacl
.utils
.rand_nonce()
770 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
771 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
773 sock
.settimeout(timeout
)
775 sock
.connect(("127.0.0.1", cls
._consolePort
))
777 theirNonce
= sock
.recv(len(ourNonce
))
778 if len(theirNonce
) != len(ourNonce
):
779 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
780 if len(theirNonce
) == 0:
781 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
784 halfNonceSize
= int(len(ourNonce
) / 2)
785 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
786 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
787 msg
= cls
._encryptConsole
(command
, writingNonce
)
788 sock
.send(struct
.pack("!I", len(msg
)))
792 raise socket
.error("Got EOF while reading the response size")
794 (responseLen
,) = struct
.unpack("!I", data
)
795 data
= sock
.recv(responseLen
)
796 response
= cls
._decryptConsole
(data
, readingNonce
)
800 def compareOptions(self
, a
, b
):
801 self
.assertEqual(len(a
), len(b
))
802 for idx
in range(len(a
)):
803 self
.assertEqual(a
[idx
], b
[idx
])
805 def checkMessageNoEDNS(self
, expected
, received
):
806 self
.assertEqual(expected
, received
)
807 self
.assertEqual(received
.edns
, -1)
808 self
.assertEqual(len(received
.options
), 0)
810 def checkMessageEDNSWithoutOptions(self
, expected
, received
):
811 self
.assertEqual(expected
, received
)
812 self
.assertEqual(received
.edns
, 0)
813 self
.assertEqual(expected
.payload
, received
.payload
)
815 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
816 self
.assertEqual(expected
, received
)
817 self
.assertEqual(received
.edns
, 0)
818 self
.assertEqual(expected
.payload
, received
.payload
)
819 self
.assertEqual(len(received
.options
), withCookies
)
821 for option
in received
.options
:
822 self
.assertEqual(option
.otype
, 10)
824 for option
in received
.options
:
825 self
.assertNotEqual(option
.otype
, 10)
827 def checkMessageEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
828 self
.assertEqual(expected
, received
)
829 self
.assertEqual(received
.edns
, 0)
830 self
.assertEqual(expected
.payload
, received
.payload
)
831 self
.assertEqual(len(received
.options
), 1 + additionalOptions
)
833 for option
in received
.options
:
834 if option
.otype
== clientsubnetoption
.ASSIGNED_OPTION_CODE
:
837 self
.assertNotEqual(additionalOptions
, 0)
839 self
.compareOptions(expected
.options
, received
.options
)
840 self
.assertTrue(hasECS
)
842 def checkMessageEDNS(self
, expected
, received
):
843 self
.assertEqual(expected
, received
)
844 self
.assertEqual(received
.edns
, 0)
845 self
.assertEqual(expected
.payload
, received
.payload
)
846 self
.assertEqual(len(expected
.options
), len(received
.options
))
847 self
.compareOptions(expected
.options
, received
.options
)
849 def checkQueryEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
850 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
852 def checkQueryEDNS(self
, expected
, received
):
853 self
.checkMessageEDNS(expected
, received
)
855 def checkResponseEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
856 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
858 def checkQueryEDNSWithoutECS(self
, expected
, received
):
859 self
.checkMessageEDNSWithoutECS(expected
, received
)
861 def checkResponseEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
862 self
.checkMessageEDNSWithoutECS(expected
, received
, withCookies
)
864 def checkQueryNoEDNS(self
, expected
, received
):
865 self
.checkMessageNoEDNS(expected
, received
)
867 def checkResponseNoEDNS(self
, expected
, received
):
868 self
.checkMessageNoEDNS(expected
, received
)
870 def generateNewCertificateAndKey(self
):
871 # generate and sign a new cert
872 cmd
= ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf']
875 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
876 output
= process
.communicate(input='')
877 except subprocess
.CalledProcessError
as exc
:
878 raise AssertionError('openssl req failed (%d): %s' % (exc
.returncode
, exc
.output
))
879 cmd
= ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
882 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
883 output
= process
.communicate(input='')
884 except subprocess
.CalledProcessError
as exc
:
885 raise AssertionError('openssl x509 failed (%d): %s' % (exc
.returncode
, exc
.output
))
887 with
open('server.chain', 'w') as outFile
:
888 for inFileName
in ['server.pem', 'ca.pem']:
889 with
open(inFileName
) as inFile
:
890 outFile
.write(inFile
.read())
892 cmd
= ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12']
895 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
896 output
= process
.communicate(input='')
897 except subprocess
.CalledProcessError
as exc
:
898 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc
.returncode
, exc
.output
))
900 def checkMessageProxyProtocol(self
, receivedProxyPayload
, source
, destination
, isTCP
, values
=[], v6
=False, sourcePort
=None, destinationPort
=None):
901 proxy
= ProxyProtocol()
902 self
.assertTrue(proxy
.parseHeader(receivedProxyPayload
))
903 self
.assertEqual(proxy
.version
, 0x02)
904 self
.assertEqual(proxy
.command
, 0x01)
906 self
.assertEqual(proxy
.family
, 0x02)
908 self
.assertEqual(proxy
.family
, 0x01)
910 self
.assertEqual(proxy
.protocol
, 0x02)
912 self
.assertEqual(proxy
.protocol
, 0x01)
913 self
.assertGreater(proxy
.contentLen
, 0)
915 self
.assertTrue(proxy
.parseAddressesAndPorts(receivedProxyPayload
))
916 self
.assertEqual(proxy
.source
, source
)
917 self
.assertEqual(proxy
.destination
, destination
)
919 self
.assertEqual(proxy
.sourcePort
, sourcePort
)
921 self
.assertEqual(proxy
.destinationPort
, destinationPort
)
923 self
.assertEqual(proxy
.destinationPort
, self
._dnsDistPort
)
925 self
.assertTrue(proxy
.parseAdditionalValues(receivedProxyPayload
))
928 self
.assertEqual(proxy
.values
, values
)
931 def getDOHGetURL(cls
, baseurl
, query
, rawQuery
=False):
935 wire
= query
.to_wire()
936 param
= base64
.urlsafe_b64encode(wire
).decode('UTF8').rstrip('=')
937 return baseurl
+ "?dns=" + param
940 def openDOHConnection(cls
, port
, caFile
, timeout
=2.0):
942 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2
)
944 conn
.setopt(pycurl
.HTTPHEADER
, ["Content-type: application/dns-message",
945 "Accept: application/dns-message"])
949 def sendDOHQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True, fromQueue
=None, toQueue
=None):
950 url
= cls
.getDOHGetURL(baseurl
, query
, rawQuery
)
951 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
952 response_headers
= BytesIO()
953 #conn.setopt(pycurl.VERBOSE, True)
954 conn
.setopt(pycurl
.URL
, url
)
955 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
957 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
958 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
960 conn
.setopt(pycurl
.CAINFO
, caFile
)
962 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
963 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
967 toQueue
.put(response
, True, timeout
)
969 cls
._toResponderQueue
.put(response
, True, timeout
)
973 cls
._response
_headers
= ''
974 data
= conn
.perform_rb()
975 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
976 if cls
._rcode
== 200 and not rawResponse
:
977 message
= dns
.message
.from_wire(data
)
983 if not fromQueue
.empty():
984 receivedQuery
= fromQueue
.get(True, timeout
)
986 if not cls
._fromResponderQueue
.empty():
987 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
989 cls
._response
_headers
= response_headers
.getvalue()
990 return (receivedQuery
, message
)
993 def sendDOHPostQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True):
995 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
996 response_headers
= BytesIO()
997 #conn.setopt(pycurl.VERBOSE, True)
998 conn
.setopt(pycurl
.URL
, url
)
999 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1001 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1002 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1004 conn
.setopt(pycurl
.CAINFO
, caFile
)
1006 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1007 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1008 conn
.setopt(pycurl
.POST
, True)
1011 data
= data
.to_wire()
1013 conn
.setopt(pycurl
.POSTFIELDS
, data
)
1016 cls
._toResponderQueue
.put(response
, True, timeout
)
1018 receivedQuery
= None
1020 cls
._response
_headers
= ''
1021 data
= conn
.perform_rb()
1022 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1023 if cls
._rcode
== 200 and not rawResponse
:
1024 message
= dns
.message
.from_wire(data
)
1028 if useQueue
and not cls
._fromResponderQueue
.empty():
1029 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1031 cls
._response
_headers
= response_headers
.getvalue()
1032 return (receivedQuery
, message
)
1034 def sendDOHQueryWrapper(self
, query
, response
, useQueue
=True):
1035 return self
.sendDOHQuery(self
._dohServerPort
, self
._serverName
, self
._dohBaseURL
, query
, response
=response
, caFile
=self
._caCert
, useQueue
=useQueue
)
1037 def sendDOTQueryWrapper(self
, query
, response
, useQueue
=True):
1038 return self
.sendDOTQuery(self
._tlsServerPort
, self
._serverName
, query
, response
, self
._caCert
, useQueue
=useQueue
)