16 import clientsubnetoption
29 from io
import BytesIO
31 from doqclient
import quic_query
32 from doh3client
import doh3_query
34 from eqdnsmessage
import AssertEqualDNSMessageMixin
35 from proxyprotocol
import ProxyProtocol
37 # Python2/3 compatibility hacks
39 from queue
import Queue
41 from Queue
import Queue
49 if not 'PYTEST_XDIST_WORKER' in os
.environ
:
51 workerName
= os
.environ
['PYTEST_XDIST_WORKER']
52 return int(workerName
[2:])
56 def pickAvailablePort():
58 workerID
= getWorkerID()
59 if workerID
in workerPorts
:
60 port
= workerPorts
[workerID
] + 1
62 port
= 11000 + (workerID
* 1000)
63 workerPorts
[workerID
] = port
66 class DNSDistTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
68 Set up a dnsdist instance and responder threads.
69 Queries sent to dnsdist are relayed to the responder threads,
70 who reply with the response provided by the tests themselves
71 on a queue. Responder threads also queue the queries received
72 from dnsdist on a separate queue, allowing the tests to check
73 that the queries sent from dnsdist were as expected.
75 _dnsDistListeningAddr
= "127.0.0.1"
76 _toResponderQueue
= Queue()
77 _fromResponderQueue
= Queue()
80 _responsesCounter
= {}
81 _config_template
= """
83 _config_params
= ['_testServerPort']
84 _acl
= ['127.0.0.1/32']
86 _healthCheckName
= 'a.root-servers.net.'
87 _healthCheckCounter
= 0
88 _answerUnexpected
= True
89 _checkConfigExpectedOutput
= None
91 _skipListeningOnCL
= False
92 _alternateListeningAddr
= None
93 _alternateListeningPort
= None
94 _backgroundThreads
= {}
97 _extraStartupSleep
= 0
98 _dnsDistPort
= pickAvailablePort()
99 _consolePort
= pickAvailablePort()
100 _testServerPort
= pickAvailablePort()
103 def waitForTCPSocket(cls
, ipaddress
, port
):
104 for try_number
in range(0, 20):
106 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
108 sock
.connect((ipaddress
, port
))
111 except Exception as err
:
112 if err
.errno
!= errno
.ECONNREFUSED
:
113 print(f
'Error occurred: {try_number} {err}', file=sys
.stderr
)
115 # We assume the dnsdist instance does not listen. That's fine.
118 def startResponders(cls
):
119 print("Launching responders..")
120 cls
._testServerPort
= pickAvailablePort()
122 cls
._UDPResponder
= threading
.Thread(name
='UDP Responder', target
=cls
.UDPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
123 cls
._UDPResponder
.daemon
= True
124 cls
._UDPResponder
.start()
125 cls
._TCPResponder
= threading
.Thread(name
='TCP Responder', target
=cls
.TCPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
126 cls
._TCPResponder
.daemon
= True
127 cls
._TCPResponder
.start()
128 cls
.waitForTCPSocket("127.0.0.1", cls
._testServerPort
);
131 def startDNSDist(cls
):
132 cls
._dnsDistPort
= pickAvailablePort()
133 cls
._consolePort
= pickAvailablePort()
135 print("Launching dnsdist..")
136 confFile
= os
.path
.join('configs', 'dnsdist_%s.conf' % (cls
.__name
__))
137 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
139 with
open(confFile
, 'w') as conf
:
140 conf
.write("-- Autogenerated by dnsdisttests.py\n")
141 conf
.write(f
"-- dnsdist will listen on {cls._dnsDistPort}")
142 conf
.write(cls
._config
_template
% params
)
143 conf
.write("setSecurityPollSuffix('')")
145 if cls
._skipListeningOnCL
:
146 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
]
148 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
,
149 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
152 dnsdistcmd
.append('-v')
155 dnsdistcmd
.extend(['--acl', acl
])
156 print(' '.join(dnsdistcmd
))
158 # validate config with --check-config, which sets client=true, possibly exposing bugs.
159 testcmd
= dnsdistcmd
+ ['--check-config']
161 output
= subprocess
.check_output(testcmd
, stderr
=subprocess
.STDOUT
, close_fds
=True)
162 except subprocess
.CalledProcessError
as exc
:
163 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc
.returncode
, exc
.output
))
164 if cls
._checkConfigExpectedOutput
is not None:
165 expectedOutput
= cls
._checkConfigExpectedOutput
167 expectedOutput
= ('Configuration \'%s\' OK!\n' % (confFile
)).encode()
168 if not cls
._verboseMode
and output
!= expectedOutput
:
169 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output
, expectedOutput
))
171 logFile
= os
.path
.join('configs', 'dnsdist_%s.log' % (cls
.__name
__))
172 with
open(logFile
, 'w') as fdLog
:
173 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdLog
, stderr
=fdLog
)
175 if cls
._alternateListeningAddr
and cls
._alternateListeningPort
:
176 cls
.waitForTCPSocket(cls
._alternateListeningAddr
, cls
._alternateListeningPort
)
178 cls
.waitForTCPSocket(cls
._dnsDistListeningAddr
, cls
._dnsDistPort
)
180 if cls
._dnsdist
.poll() is not None:
181 print(f
"\n*** startDNSDist log for {logFile} ***")
182 with
open(logFile
, 'r') as fdLog
:
184 print(f
"*** End startDNSDist log for {logFile} ***")
185 raise AssertionError('%s failed (%d)' % (dnsdistcmd
, cls
._dnsdist
.returncode
))
186 time
.sleep(cls
._extraStartupSleep
)
189 def setUpSockets(cls
):
190 print("Setting up UDP socket..")
191 cls
._sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
192 cls
._sock
.settimeout(2.0)
193 cls
._sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
196 def killProcess(cls
, p
):
197 # Don't try to kill it if it's already dead
198 if p
.poll() is not None:
202 for count
in range(20):
208 print("kill...", p
, file=sys
.stderr
)
212 # There is a race-condition with the poll() and
213 # kill() statements, when the process is dead on the
214 # kill(), this is fine
215 if e
.errno
!= errno
.ESRCH
:
221 cls
.startResponders()
225 print("Launching tests..")
228 def tearDownClass(cls
):
230 # tell the background threads to stop, if any
231 for backgroundThread
in cls
._backgroundThreads
:
232 cls
._backgroundThreads
[backgroundThread
] = False
233 cls
.killProcess(cls
._dnsdist
)
236 def _ResponderIncrementCounter(cls
):
237 if threading
.current_thread().name
in cls
._responsesCounter
:
238 cls
._responsesCounter
[threading
.current_thread().name
] += 1
240 cls
._responsesCounter
[threading
.current_thread().name
] = 1
243 def _getResponse(cls
, request
, fromQueue
, toQueue
, synthesize
=None):
245 if len(request
.question
) != 1:
246 print("Skipping query with question count %d" % (len(request
.question
)))
248 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
250 cls
._healthCheckCounter
+= 1
251 response
= dns
.message
.make_response(request
)
253 cls
._ResponderIncrementCounter
()
254 if not fromQueue
.empty():
255 toQueue
.put(request
, True, cls
._queueTimeout
)
256 response
= fromQueue
.get(True, cls
._queueTimeout
)
258 response
= copy
.copy(response
)
259 response
.id = request
.id
261 if synthesize
is not None:
262 response
= dns
.message
.make_response(request
)
263 response
.set_rcode(synthesize
)
266 if cls
._answerUnexpected
:
267 response
= dns
.message
.make_response(request
)
268 response
.set_rcode(dns
.rcode
.SERVFAIL
)
273 def UDPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, callback
=None):
274 cls
._backgroundThreads
[threading
.get_native_id()] = True
275 # trailingDataResponse=True means "ignore trailing data".
276 # Other values are either False (meaning "raise an exception")
277 # or are interpreted as a response RCODE for queries with trailing data.
278 # callback is invoked for every -even healthcheck ones- query and should return a raw response
279 ignoreTrailing
= trailingDataResponse
is True
281 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
282 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
283 sock
.bind(("127.0.0.1", port
))
287 data
, addr
= sock
.recvfrom(4096)
288 except socket
.timeout
:
289 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
290 del cls
._backgroundThreads
[threading
.get_native_id()]
297 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
298 except dns
.message
.TrailingJunk
as e
:
299 print('trailing data exception in UDPResponder')
300 if trailingDataResponse
is False or forceRcode
is True:
302 print("UDP query with trailing data, synthesizing response")
303 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
304 forceRcode
= trailingDataResponse
308 wire
= callback(request
)
311 forceRcode
= dns
.rcode
.BADVERS
312 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
314 wire
= response
.to_wire()
319 sock
.sendto(wire
, addr
)
324 def handleTCPConnection(cls
, conn
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, partialWrite
=False):
325 ignoreTrailing
= trailingDataResponse
is True
328 except Exception as err
:
330 print(f
'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
335 (datalen
,) = struct
.unpack("!H", data
)
336 data
= conn
.recv(datalen
)
339 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
340 except dns
.message
.TrailingJunk
as e
:
341 if trailingDataResponse
is False or forceRcode
is True:
343 print("TCP query with trailing data, synthesizing response")
344 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
345 forceRcode
= trailingDataResponse
348 wire
= callback(request
)
351 forceRcode
= dns
.rcode
.BADVERS
352 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
354 wire
= response
.to_wire(max_size
=65535)
360 wireLen
= struct
.pack("!H", len(wire
))
363 conn
.send(bytes([b
]))
369 while multipleResponses
:
370 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
371 # otherwise we might read responses intended for the next connection
372 if fromQueue
.empty():
375 response
= fromQueue
.get(False)
379 response
= copy
.copy(response
)
380 response
.id = request
.id
381 wire
= response
.to_wire(max_size
=65535)
383 conn
.send(struct
.pack("!H", len(wire
)))
385 except socket
.error
as e
:
386 # some of the tests are going to close
387 # the connection on us, just deal with it
393 def TCPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, multipleConnections
=False, listeningAddr
='127.0.0.1', partialWrite
=False):
394 cls
._backgroundThreads
[threading
.get_native_id()] = True
395 # trailingDataResponse=True means "ignore trailing data".
396 # Other values are either False (meaning "raise an exception")
397 # or are interpreted as a response RCODE for queries with trailing data.
398 # callback is invoked for every -even healthcheck ones- query and should return a raw response
400 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
401 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
402 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
404 sock
.bind((listeningAddr
, port
))
405 except socket
.error
as e
:
406 print("Error binding in the TCP responder: %s" % str(e
))
412 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
416 (conn
, _
) = sock
.accept()
419 except ConnectionResetError
:
421 except socket
.timeout
:
422 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
423 del cls
._backgroundThreads
[threading
.get_native_id()]
429 if multipleConnections
:
430 thread
= threading
.Thread(name
='TCP Connection Handler',
431 target
=cls
.handleTCPConnection
,
432 args
=[conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, partialWrite
])
436 cls
.handleTCPConnection(conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, partialWrite
)
441 def handleDoHConnection(cls
, config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
):
442 ignoreTrailing
= trailingDataResponse
is True
444 h2conn
= h2
.connection
.H2Connection(config
=config
)
445 h2conn
.initiate_connection()
446 conn
.sendall(h2conn
.data_to_send())
447 except ssl
.SSLEOFError
as e
:
448 print("Unexpected EOF: %s" % (e
))
450 except Exception as err
:
451 print(f
'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
457 # try to read the entire Proxy Protocol header
458 proxy
= ProxyProtocol()
459 header
= conn
.recv(proxy
.HEADER_SIZE
)
461 print('unable to get header')
465 if not proxy
.parseHeader(header
):
466 print('unable to parse header')
471 proxyContent
= conn
.recv(proxy
.contentLen
)
473 print('unable to get content')
477 payload
= header
+ proxyContent
478 toQueue
.put(payload
, True, cls
._queueTimeout
)
480 # be careful, HTTP/2 headers and data might be in different recv() results
481 requestHeaders
= None
484 data
= conn
.recv(65535)
485 except Exception as err
:
487 print(f
'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
491 events
= h2conn
.receive_data(data
)
493 if isinstance(event
, h2
.events
.RequestReceived
):
494 requestHeaders
= event
.headers
495 if isinstance(event
, h2
.events
.DataReceived
):
496 h2conn
.acknowledge_received_data(event
.flow_controlled_length
, event
.stream_id
)
497 if not event
.stream_id
in dnsData
:
498 dnsData
[event
.stream_id
] = b
''
499 dnsData
[event
.stream_id
] = dnsData
[event
.stream_id
] + (event
.data
)
500 if event
.stream_ended
:
504 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=ignoreTrailing
)
505 except dns
.message
.TrailingJunk
as e
:
506 if trailingDataResponse
is False or forceRcode
is True:
508 print("DOH query with trailing data, synthesizing response")
509 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=True)
510 forceRcode
= trailingDataResponse
513 status
, wire
= callback(request
, requestHeaders
, fromQueue
, toQueue
)
515 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
517 wire
= response
.to_wire(max_size
=65535)
525 (':status', str(status
)),
526 ('content-length', str(len(wire
))),
527 ('content-type', 'application/dns-message'),
529 h2conn
.send_headers(stream_id
=event
.stream_id
, headers
=headers
)
530 h2conn
.send_data(stream_id
=event
.stream_id
, data
=wire
, end_stream
=True)
532 data_to_send
= h2conn
.data_to_send()
534 conn
.sendall(data_to_send
)
543 def DOHResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, useProxyProtocol
=False):
544 cls
._backgroundThreads
[threading
.get_native_id()] = True
545 # trailingDataResponse=True means "ignore trailing data".
546 # Other values are either False (meaning "raise an exception")
547 # or are interpreted as a response RCODE for queries with trailing data.
548 # callback is invoked for every -even healthcheck ones- query and should return a raw response
550 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
551 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
552 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
554 sock
.bind(("127.0.0.1", port
))
555 except socket
.error
as e
:
556 print("Error binding in the TCP responder: %s" % str(e
))
562 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
564 config
= h2
.config
.H2Configuration(client_side
=False)
568 (conn
, _
) = sock
.accept()
571 except ConnectionResetError
:
573 except socket
.timeout
:
574 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
575 del cls
._backgroundThreads
[threading
.get_native_id()]
581 thread
= threading
.Thread(name
='DoH Connection Handler',
582 target
=cls
.handleDoHConnection
,
583 args
=[config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
])
590 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
591 if useQueue
and response
is not None:
592 cls
._toResponderQueue
.put(response
, True, timeout
)
595 cls
._sock
.settimeout(timeout
)
599 query
= query
.to_wire()
600 cls
._sock
.send(query
)
601 data
= cls
._sock
.recv(4096)
602 except socket
.timeout
:
606 cls
._sock
.settimeout(None)
610 if useQueue
and not cls
._fromResponderQueue
.empty():
611 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
613 message
= dns
.message
.from_wire(data
)
614 return (receivedQuery
, message
)
617 def openTCPConnection(cls
, timeout
=None, port
=None):
618 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
619 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
621 sock
.settimeout(timeout
)
624 port
= cls
._dnsDistPort
626 sock
.connect(("127.0.0.1", port
))
630 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None, alpn
=[]):
631 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
632 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
634 sock
.settimeout(timeout
)
637 if hasattr(ssl
, 'create_default_context'):
638 sslctx
= ssl
.create_default_context(cafile
=caCert
)
639 if len(alpn
)> 0 and hasattr(sslctx
, 'set_alpn_protocols'):
640 sslctx
.set_alpn_protocols(alpn
)
641 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
643 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
645 sslsock
.connect(("127.0.0.1", port
))
649 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
651 wire
= query
.to_wire()
656 cls
._toResponderQueue
.put(response
, True, timeout
)
658 sock
.send(struct
.pack("!H", len(wire
)))
662 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
666 (datalen
,) = struct
.unpack("!H", data
)
668 data
= sock
.recv(datalen
)
671 message
= dns
.message
.from_wire(data
)
674 if useQueue
and not cls
._fromResponderQueue
.empty():
675 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
677 return (receivedQuery
, message
)
683 def sendDOTQuery(cls
, port
, serverName
, query
, response
, caFile
, useQueue
=True):
684 conn
= cls
.openTLSConnection(port
, serverName
, caFile
)
685 cls
.sendTCPQueryOverConnection(conn
, query
, response
=response
)
687 return cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
688 return None, cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
691 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
694 cls
._toResponderQueue
.put(response
, True, timeout
)
696 sock
= cls
.openTCPConnection(timeout
)
699 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
700 message
= cls
.recvTCPResponseOverConnection(sock
)
701 except socket
.timeout
as e
:
702 print("Timeout while sending or receiving TCP data: %s" % (str(e
)))
703 except socket
.error
as e
:
704 print("Network error: %s" % (str(e
)))
710 if useQueue
and not cls
._fromResponderQueue
.empty():
712 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
714 print("queue is empty")
716 return (receivedQuery
, message
)
719 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
721 for response
in responses
:
722 cls
._toResponderQueue
.put(response
, True, timeout
)
723 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
724 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
726 sock
.settimeout(timeout
)
728 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
733 wire
= query
.to_wire()
737 sock
.send(struct
.pack("!H", len(wire
)))
743 (datalen
,) = struct
.unpack("!H", data
)
744 data
= sock
.recv(datalen
)
745 messages
.append(dns
.message
.from_wire(data
))
747 except socket
.timeout
as e
:
748 print("Timeout while receiving multiple TCP responses: %s" % (str(e
)))
749 except socket
.error
as e
:
750 print("Network error: %s" % (str(e
)))
755 if useQueue
and not cls
._fromResponderQueue
.empty():
756 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
757 return (receivedQuery
, messages
)
760 # This function is called before every test
762 # Clear the responses counters
763 self
._responsesCounter
.clear()
765 self
._healthCheckCounter
= 0
767 # Make sure the queues are empty, in case
768 # a previous test failed
769 self
.clearResponderQueues()
771 super(DNSDistTest
, self
).setUp()
774 def clearToResponderQueue(cls
):
775 while not cls
._toResponderQueue
.empty():
776 cls
._toResponderQueue
.get(False)
779 def clearFromResponderQueue(cls
):
780 while not cls
._fromResponderQueue
.empty():
781 cls
._fromResponderQueue
.get(False)
784 def clearResponderQueues(cls
):
785 cls
.clearToResponderQueue()
786 cls
.clearFromResponderQueue()
789 def generateConsoleKey():
790 return libnacl
.utils
.salsa_key()
793 def _encryptConsole(cls
, command
, nonce
):
794 command
= command
.encode('UTF-8')
795 if cls
._consoleKey
is None:
797 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
800 def _decryptConsole(cls
, command
, nonce
):
801 if cls
._consoleKey
is None:
804 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
805 return result
.decode('UTF-8')
808 def sendConsoleCommand(cls
, command
, timeout
=5.0):
809 ourNonce
= libnacl
.utils
.rand_nonce()
811 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
812 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
814 sock
.settimeout(timeout
)
816 sock
.connect(("127.0.0.1", cls
._consolePort
))
818 theirNonce
= sock
.recv(len(ourNonce
))
819 if len(theirNonce
) != len(ourNonce
):
820 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
821 if len(theirNonce
) == 0:
822 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
825 halfNonceSize
= int(len(ourNonce
) / 2)
826 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
827 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
828 msg
= cls
._encryptConsole
(command
, writingNonce
)
829 sock
.send(struct
.pack("!I", len(msg
)))
833 raise socket
.error("Got EOF while reading the response size")
835 (responseLen
,) = struct
.unpack("!I", data
)
836 data
= sock
.recv(responseLen
)
837 response
= cls
._decryptConsole
(data
, readingNonce
)
841 def compareOptions(self
, a
, b
):
842 self
.assertEqual(len(a
), len(b
))
843 for idx
in range(len(a
)):
844 self
.assertEqual(a
[idx
], b
[idx
])
846 def checkMessageNoEDNS(self
, expected
, received
):
847 self
.assertEqual(expected
, received
)
848 self
.assertEqual(received
.edns
, -1)
849 self
.assertEqual(len(received
.options
), 0)
851 def checkMessageEDNSWithoutOptions(self
, expected
, received
):
852 self
.assertEqual(expected
, received
)
853 self
.assertEqual(received
.edns
, 0)
854 self
.assertEqual(expected
.payload
, received
.payload
)
856 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
857 self
.assertEqual(expected
, received
)
858 self
.assertEqual(received
.edns
, 0)
859 self
.assertEqual(expected
.payload
, received
.payload
)
860 self
.assertEqual(len(received
.options
), withCookies
)
862 for option
in received
.options
:
863 self
.assertEqual(option
.otype
, 10)
865 for option
in received
.options
:
866 self
.assertNotEqual(option
.otype
, 10)
868 def checkMessageEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
869 self
.assertEqual(expected
, received
)
870 self
.assertEqual(received
.edns
, 0)
871 self
.assertEqual(expected
.payload
, received
.payload
)
872 self
.assertEqual(len(received
.options
), 1 + additionalOptions
)
874 for option
in received
.options
:
875 if option
.otype
== clientsubnetoption
.ASSIGNED_OPTION_CODE
:
878 self
.assertNotEqual(additionalOptions
, 0)
880 self
.compareOptions(expected
.options
, received
.options
)
881 self
.assertTrue(hasECS
)
883 def checkMessageEDNS(self
, expected
, received
):
884 self
.assertEqual(expected
, received
)
885 self
.assertEqual(received
.edns
, 0)
886 self
.assertEqual(expected
.payload
, received
.payload
)
887 self
.assertEqual(len(expected
.options
), len(received
.options
))
888 self
.compareOptions(expected
.options
, received
.options
)
890 def checkQueryEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
891 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
893 def checkQueryEDNS(self
, expected
, received
):
894 self
.checkMessageEDNS(expected
, received
)
896 def checkResponseEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
897 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
899 def checkQueryEDNSWithoutECS(self
, expected
, received
):
900 self
.checkMessageEDNSWithoutECS(expected
, received
)
902 def checkResponseEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
903 self
.checkMessageEDNSWithoutECS(expected
, received
, withCookies
)
905 def checkQueryNoEDNS(self
, expected
, received
):
906 self
.checkMessageNoEDNS(expected
, received
)
908 def checkResponseNoEDNS(self
, expected
, received
):
909 self
.checkMessageNoEDNS(expected
, received
)
912 def generateNewCertificateAndKey(filePrefix
):
913 # generate and sign a new cert
914 cmd
= ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix
+ '.key', '-out', filePrefix
+ '.csr', '-config', 'configServer.conf']
917 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
918 output
= process
.communicate(input='')
919 except subprocess
.CalledProcessError
as exc
:
920 raise AssertionError('openssl req failed (%d): %s' % (exc
.returncode
, exc
.output
))
921 cmd
= ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix
+ '.csr', '-out', filePrefix
+ '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
924 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
925 output
= process
.communicate(input='')
926 except subprocess
.CalledProcessError
as exc
:
927 raise AssertionError('openssl x509 failed (%d): %s' % (exc
.returncode
, exc
.output
))
929 with
open(filePrefix
+ '.chain', 'w') as outFile
:
930 for inFileName
in [filePrefix
+ '.pem', 'ca.pem']:
931 with
open(inFileName
) as inFile
:
932 outFile
.write(inFile
.read())
934 cmd
= ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix
+ '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix
+ '.key', '-out', filePrefix
+ '.p12']
937 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
938 output
= process
.communicate(input='')
939 except subprocess
.CalledProcessError
as exc
:
940 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc
.returncode
, exc
.output
))
942 def checkMessageProxyProtocol(self
, receivedProxyPayload
, source
, destination
, isTCP
, values
=[], v6
=False, sourcePort
=None, destinationPort
=None):
943 proxy
= ProxyProtocol()
944 self
.assertTrue(proxy
.parseHeader(receivedProxyPayload
))
945 self
.assertEqual(proxy
.version
, 0x02)
946 self
.assertEqual(proxy
.command
, 0x01)
948 self
.assertEqual(proxy
.family
, 0x02)
950 self
.assertEqual(proxy
.family
, 0x01)
952 self
.assertEqual(proxy
.protocol
, 0x02)
954 self
.assertEqual(proxy
.protocol
, 0x01)
955 self
.assertGreater(proxy
.contentLen
, 0)
957 self
.assertTrue(proxy
.parseAddressesAndPorts(receivedProxyPayload
))
958 self
.assertEqual(proxy
.source
, source
)
959 self
.assertEqual(proxy
.destination
, destination
)
961 self
.assertEqual(proxy
.sourcePort
, sourcePort
)
963 self
.assertEqual(proxy
.destinationPort
, destinationPort
)
965 self
.assertEqual(proxy
.destinationPort
, self
._dnsDistPort
)
967 self
.assertTrue(proxy
.parseAdditionalValues(receivedProxyPayload
))
970 self
.assertEqual(proxy
.values
, values
)
973 def getDOHGetURL(cls
, baseurl
, query
, rawQuery
=False):
977 wire
= query
.to_wire()
978 param
= base64
.urlsafe_b64encode(wire
).decode('UTF8').rstrip('=')
979 return baseurl
+ "?dns=" + param
982 def openDOHConnection(cls
, port
, caFile
, timeout
=2.0):
984 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2
)
986 conn
.setopt(pycurl
.HTTPHEADER
, ["Content-type: application/dns-message",
987 "Accept: application/dns-message"])
991 def sendDOHQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True, fromQueue
=None, toQueue
=None, conn
=None):
992 url
= cls
.getDOHGetURL(baseurl
, query
, rawQuery
)
995 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
996 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
997 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE
)
1000 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1001 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1003 conn
.setopt(pycurl
.CAINFO
, caFile
)
1005 response_headers
= BytesIO()
1006 #conn.setopt(pycurl.VERBOSE, True)
1007 conn
.setopt(pycurl
.URL
, url
)
1008 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1010 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1011 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1015 toQueue
.put(response
, True, timeout
)
1017 cls
._toResponderQueue
.put(response
, True, timeout
)
1019 receivedQuery
= None
1021 cls
._response
_headers
= ''
1022 data
= conn
.perform_rb()
1023 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1024 if cls
._rcode
== 200 and not rawResponse
:
1025 message
= dns
.message
.from_wire(data
)
1031 if not fromQueue
.empty():
1032 receivedQuery
= fromQueue
.get(True, timeout
)
1034 if not cls
._fromResponderQueue
.empty():
1035 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1037 cls
._response
_headers
= response_headers
.getvalue()
1038 return (receivedQuery
, message
)
1041 def sendDOHPostQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True):
1043 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
1044 response_headers
= BytesIO()
1045 #conn.setopt(pycurl.VERBOSE, True)
1046 conn
.setopt(pycurl
.URL
, url
)
1047 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1048 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1049 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE
)
1051 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1052 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1054 conn
.setopt(pycurl
.CAINFO
, caFile
)
1056 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1057 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1058 conn
.setopt(pycurl
.POST
, True)
1061 data
= data
.to_wire()
1063 conn
.setopt(pycurl
.POSTFIELDS
, data
)
1066 cls
._toResponderQueue
.put(response
, True, timeout
)
1068 receivedQuery
= None
1070 cls
._response
_headers
= ''
1071 data
= conn
.perform_rb()
1072 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1073 if cls
._rcode
== 200 and not rawResponse
:
1074 message
= dns
.message
.from_wire(data
)
1078 if useQueue
and not cls
._fromResponderQueue
.empty():
1079 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1081 cls
._response
_headers
= response_headers
.getvalue()
1082 return (receivedQuery
, message
)
1084 def sendDOHQueryWrapper(self
, query
, response
, useQueue
=True):
1085 return self
.sendDOHQuery(self
._dohServerPort
, self
._serverName
, self
._dohBaseURL
, query
, response
=response
, caFile
=self
._caCert
, useQueue
=useQueue
)
1087 def sendDOHWithNGHTTP2QueryWrapper(self
, query
, response
, useQueue
=True):
1088 return self
.sendDOHQuery(self
._dohWithNGHTTP
2ServerPort
, self
._serverName
, self
._dohWithNGHTTP
2BaseURL
, query
, response
=response
, caFile
=self
._caCert
, useQueue
=useQueue
)
1090 def sendDOHWithH2OQueryWrapper(self
, query
, response
, useQueue
=True):
1091 return self
.sendDOHQuery(self
._dohWithH
2OServerPort
, self
._serverName
, self
._dohWithH
2OBaseURL
, query
, response
=response
, caFile
=self
._caCert
, useQueue
=useQueue
)
1093 def sendDOTQueryWrapper(self
, query
, response
, useQueue
=True):
1094 return self
.sendDOTQuery(self
._tlsServerPort
, self
._serverName
, query
, response
, self
._caCert
, useQueue
=useQueue
)
1096 def sendDOQQueryWrapper(self
, query
, response
, useQueue
=True):
1097 return self
.sendDOQQuery(self
._doqServerPort
, query
, response
=response
, caFile
=self
._caCert
, useQueue
=useQueue
, serverName
=self
._serverName
)
1099 def sendDOH3QueryWrapper(self
, query
, response
, useQueue
=True):
1100 return self
.sendDOH3Query(self
._doh
3ServerPort
, self
._dohBaseURL
, query
, response
=response
, caFile
=self
._caCert
, useQueue
=useQueue
, serverName
=self
._serverName
)
1102 def getDOQConnection(cls
, port
, caFile
=None, source
=None, source_port
=0):
1104 manager
= dns
.quic
.SyncQuicManager(
1108 return manager
.connect('127.0.0.1', port
, source
, source_port
)
1111 def sendDOQQuery(cls
, port
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, fromQueue
=None, toQueue
=None, connection
=None, serverName
=None):
1115 toQueue
.put(response
, True, timeout
)
1117 cls
._toResponderQueue
.put(response
, True, timeout
)
1119 (message
, _
) = quic_query(query
, '127.0.0.1', timeout
, port
, verify
=caFile
, server_hostname
=serverName
)
1121 receivedQuery
= None
1125 if not fromQueue
.empty():
1126 receivedQuery
= fromQueue
.get(True, timeout
)
1128 if not cls
._fromResponderQueue
.empty():
1129 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1131 return (receivedQuery
, message
)
1134 def sendDOH3Query(cls
, port
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, fromQueue
=None, toQueue
=None, connection
=None, serverName
=None, post
=False):
1138 toQueue
.put(response
, True, timeout
)
1140 cls
._toResponderQueue
.put(response
, True, timeout
)
1142 message
= doh3_query(query
, baseurl
, timeout
, port
, verify
=caFile
, server_hostname
=serverName
, post
=post
)
1144 receivedQuery
= None
1148 if not fromQueue
.empty():
1149 receivedQuery
= fromQueue
.get(True, timeout
)
1151 if not cls
._fromResponderQueue
.empty():
1152 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1154 return (receivedQuery
, message
)