16 import clientsubnetoption
29 from io
import BytesIO
31 from doqclient
import quic_query
32 from doh3client
import doh3_query
34 from eqdnsmessage
import AssertEqualDNSMessageMixin
35 from proxyprotocol
import ProxyProtocol
37 # Python2/3 compatibility hacks
39 from queue
import Queue
41 from Queue
import Queue
49 if not 'PYTEST_XDIST_WORKER' in os
.environ
:
51 workerName
= os
.environ
['PYTEST_XDIST_WORKER']
52 return int(workerName
[2:])
56 def pickAvailablePort():
58 workerID
= getWorkerID()
59 if workerID
in workerPorts
:
60 port
= workerPorts
[workerID
] + 1
62 port
= 11000 + (workerID
* 1000)
63 workerPorts
[workerID
] = port
66 class DNSDistTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
68 Set up a dnsdist instance and responder threads.
69 Queries sent to dnsdist are relayed to the responder threads,
70 who reply with the response provided by the tests themselves
71 on a queue. Responder threads also queue the queries received
72 from dnsdist on a separate queue, allowing the tests to check
73 that the queries sent from dnsdist were as expected.
75 _dnsDistListeningAddr
= "127.0.0.1"
76 _toResponderQueue
= Queue()
77 _fromResponderQueue
= Queue()
80 _responsesCounter
= {}
81 _config_template
= """
83 _config_params
= ['_testServerPort']
84 _acl
= ['127.0.0.1/32']
86 _healthCheckName
= 'a.root-servers.net.'
87 _healthCheckCounter
= 0
88 _answerUnexpected
= True
89 _checkConfigExpectedOutput
= None
92 _skipListeningOnCL
= False
93 _alternateListeningAddr
= None
94 _alternateListeningPort
= None
95 _backgroundThreads
= {}
98 _extraStartupSleep
= 0
99 _dnsDistPort
= pickAvailablePort()
100 _consolePort
= pickAvailablePort()
101 _testServerPort
= pickAvailablePort()
104 def waitForTCPSocket(cls
, ipaddress
, port
):
105 for try_number
in range(0, 20):
107 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
109 sock
.connect((ipaddress
, port
))
112 except Exception as err
:
113 if err
.errno
!= errno
.ECONNREFUSED
:
114 print(f
'Error occurred: {try_number} {err}', file=sys
.stderr
)
116 # We assume the dnsdist instance does not listen. That's fine.
def startResponders(cls):
    """Start the UDP and TCP responder threads on a freshly picked port.

    Stores the port in ``cls._testServerPort`` and the threads in
    ``cls._UDPResponder`` / ``cls._TCPResponder``. Both threads are marked
    as daemon threads before being started. Finally calls
    ``waitForTCPSocket`` on 127.0.0.1 and the chosen port.
    """
    print("Launching responders..")
    cls._testServerPort = pickAvailablePort()

    responders = (
        ('_UDPResponder', 'UDP Responder', cls.UDPResponder),
        ('_TCPResponder', 'TCP Responder', cls.TCPResponder),
    )
    for attribute, threadName, entryPoint in responders:
        # The responder signatures are (port, fromQueue, toQueue): what the
        # tests call the "to responder" queue is the responder's input.
        worker = threading.Thread(
            name=threadName,
            target=entryPoint,
            args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue],
        )
        setattr(cls, attribute, worker)
        worker.daemon = True
        worker.start()

    cls.waitForTCPSocket("127.0.0.1", cls._testServerPort)
132 def startDNSDist(cls
):
133 cls
._dnsDistPort
= pickAvailablePort()
134 cls
._consolePort
= pickAvailablePort()
136 print("Launching dnsdist..")
137 confFile
= os
.path
.join('configs', 'dnsdist_%s.conf' % (cls
.__name
__))
138 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
140 with
open(confFile
, 'w') as conf
:
141 conf
.write("-- Autogenerated by dnsdisttests.py\n")
142 conf
.write(f
"-- dnsdist will listen on {cls._dnsDistPort}")
143 conf
.write(cls
._config
_template
% params
)
144 conf
.write("setSecurityPollSuffix('')")
146 if cls
._skipListeningOnCL
:
147 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
]
149 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
,
150 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
153 dnsdistcmd
.append('-v')
155 preserve_env_values
= ['LD_LIBRARY_PATH', 'LLVM_PROFILE_FILE']
156 for value
in preserve_env_values
:
157 if value
in os
.environ
:
158 dnsdistcmd
.insert(0, value
+ '=' + os
.environ
[value
])
159 dnsdistcmd
.insert(0, 'sudo')
162 dnsdistcmd
.extend(['--acl', acl
])
163 print(' '.join(dnsdistcmd
))
165 # validate config with --check-config, which sets client=true, possibly exposing bugs.
166 testcmd
= dnsdistcmd
+ ['--check-config']
168 output
= subprocess
.check_output(testcmd
, stderr
=subprocess
.STDOUT
, close_fds
=True)
169 except subprocess
.CalledProcessError
as exc
:
170 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc
.returncode
, exc
.output
))
171 if cls
._checkConfigExpectedOutput
is not None:
172 expectedOutput
= cls
._checkConfigExpectedOutput
174 expectedOutput
= ('Configuration \'%s\' OK!\n' % (confFile
)).encode()
175 if not cls
._verboseMode
and output
!= expectedOutput
:
176 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output
, expectedOutput
))
178 logFile
= os
.path
.join('configs', 'dnsdist_%s.log' % (cls
.__name
__))
179 with
open(logFile
, 'w') as fdLog
:
180 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdLog
, stderr
=fdLog
)
182 if cls
._alternateListeningAddr
and cls
._alternateListeningPort
:
183 cls
.waitForTCPSocket(cls
._alternateListeningAddr
, cls
._alternateListeningPort
)
185 cls
.waitForTCPSocket(cls
._dnsDistListeningAddr
, cls
._dnsDistPort
)
187 if cls
._dnsdist
.poll() is not None:
188 print(f
"\n*** startDNSDist log for {logFile} ***")
189 with
open(logFile
, 'r') as fdLog
:
191 print(f
"*** End startDNSDist log for {logFile} ***")
192 raise AssertionError('%s failed (%d)' % (dnsdistcmd
, cls
._dnsdist
.returncode
))
193 time
.sleep(cls
._extraStartupSleep
)
def setUpSockets(cls):
    """Create the UDP client socket stored in ``cls._sock``.

    The socket is an AF_INET datagram socket with a 2 second timeout,
    connected to 127.0.0.1 on ``cls._dnsDistPort``.
    """
    print("Setting up UDP socket..")
    client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    # Publish the socket on the class before configuring it, so it is
    # reachable through cls._sock even if a later call raises.
    cls._sock = client
    client.settimeout(2.0)
    client.connect(("127.0.0.1", cls._dnsDistPort))
203 def killProcess(cls
, p
):
204 # Don't try to kill it if it's already dead
205 if p
.poll() is not None:
209 for count
in range(20):
215 print("kill...", p
, file=sys
.stderr
)
219 # There is a race-condition with the poll() and
220 # kill() statements, when the process is dead on the
221 # kill(), this is fine
222 if e
.errno
!= errno
.ESRCH
:
228 cls
.startResponders()
232 print("Launching tests..")
235 def tearDownClass(cls
):
237 # tell the background threads to stop, if any
238 for backgroundThread
in cls
._backgroundThreads
:
239 cls
._backgroundThreads
[backgroundThread
] = False
240 cls
.killProcess(cls
._dnsdist
)
243 def _ResponderIncrementCounter(cls
):
244 if threading
.current_thread().name
in cls
._responsesCounter
:
245 cls
._responsesCounter
[threading
.current_thread().name
] += 1
247 cls
._responsesCounter
[threading
.current_thread().name
] = 1
250 def _getResponse(cls
, request
, fromQueue
, toQueue
, synthesize
=None):
252 if len(request
.question
) != 1:
253 print("Skipping query with question count %d" % (len(request
.question
)))
255 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
257 cls
._healthCheckCounter
+= 1
258 response
= dns
.message
.make_response(request
)
260 cls
._ResponderIncrementCounter
()
261 if not fromQueue
.empty():
262 toQueue
.put(request
, True, cls
._queueTimeout
)
263 response
= fromQueue
.get(True, cls
._queueTimeout
)
265 response
= copy
.copy(response
)
266 response
.id = request
.id
268 if synthesize
is not None:
269 response
= dns
.message
.make_response(request
)
270 response
.set_rcode(synthesize
)
273 if cls
._answerUnexpected
:
274 response
= dns
.message
.make_response(request
)
275 response
.set_rcode(dns
.rcode
.SERVFAIL
)
280 def UDPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, callback
=None):
281 cls
._backgroundThreads
[threading
.get_native_id()] = True
282 # trailingDataResponse=True means "ignore trailing data".
283 # Other values are either False (meaning "raise an exception")
284 # or are interpreted as a response RCODE for queries with trailing data.
285 # callback is invoked for every -even healthcheck ones- query and should return a raw response
286 ignoreTrailing
= trailingDataResponse
is True
288 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
289 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
290 sock
.bind(("127.0.0.1", port
))
294 data
, addr
= sock
.recvfrom(4096)
295 except socket
.timeout
:
296 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
297 del cls
._backgroundThreads
[threading
.get_native_id()]
304 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
305 except dns
.message
.TrailingJunk
as e
:
306 print('trailing data exception in UDPResponder')
307 if trailingDataResponse
is False or forceRcode
is True:
309 print("UDP query with trailing data, synthesizing response")
310 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
311 forceRcode
= trailingDataResponse
315 wire
= callback(request
)
318 forceRcode
= dns
.rcode
.BADVERS
319 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
321 wire
= response
.to_wire()
326 sock
.sendto(wire
, addr
)
331 def handleTCPConnection(cls
, conn
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, partialWrite
=False):
332 ignoreTrailing
= trailingDataResponse
is True
335 except Exception as err
:
337 print(f
'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
342 (datalen
,) = struct
.unpack("!H", data
)
343 data
= conn
.recv(datalen
)
346 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
347 except dns
.message
.TrailingJunk
as e
:
348 if trailingDataResponse
is False or forceRcode
is True:
350 print("TCP query with trailing data, synthesizing response")
351 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
352 forceRcode
= trailingDataResponse
355 wire
= callback(request
)
358 forceRcode
= dns
.rcode
.BADVERS
359 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
361 wire
= response
.to_wire(max_size
=65535)
367 wireLen
= struct
.pack("!H", len(wire
))
370 conn
.send(bytes([b
]))
376 while multipleResponses
:
377 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
378 # otherwise we might read responses intended for the next connection
379 if fromQueue
.empty():
382 response
= fromQueue
.get(False)
386 response
= copy
.copy(response
)
387 response
.id = request
.id
388 wire
= response
.to_wire(max_size
=65535)
390 conn
.send(struct
.pack("!H", len(wire
)))
392 except socket
.error
as e
:
393 # some of the tests are going to close
394 # the connection on us, just deal with it
400 def TCPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, multipleConnections
=False, listeningAddr
='127.0.0.1', partialWrite
=False):
401 cls
._backgroundThreads
[threading
.get_native_id()] = True
402 # trailingDataResponse=True means "ignore trailing data".
403 # Other values are either False (meaning "raise an exception")
404 # or are interpreted as a response RCODE for queries with trailing data.
405 # callback is invoked for every -even healthcheck ones- query and should return a raw response
407 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
408 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
409 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
411 sock
.bind((listeningAddr
, port
))
412 except socket
.error
as e
:
413 print("Error binding in the TCP responder: %s" % str(e
))
419 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
423 (conn
, _
) = sock
.accept()
426 except ConnectionResetError
:
428 except socket
.timeout
:
429 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
430 del cls
._backgroundThreads
[threading
.get_native_id()]
436 if multipleConnections
:
437 thread
= threading
.Thread(name
='TCP Connection Handler',
438 target
=cls
.handleTCPConnection
,
439 args
=[conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, partialWrite
])
443 cls
.handleTCPConnection(conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, partialWrite
)
448 def handleDoHConnection(cls
, config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
):
449 ignoreTrailing
= trailingDataResponse
is True
451 h2conn
= h2
.connection
.H2Connection(config
=config
)
452 h2conn
.initiate_connection()
453 conn
.sendall(h2conn
.data_to_send())
454 except ssl
.SSLEOFError
as e
:
455 print("Unexpected EOF: %s" % (e
))
457 except Exception as err
:
458 print(f
'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
464 # try to read the entire Proxy Protocol header
465 proxy
= ProxyProtocol()
466 header
= conn
.recv(proxy
.HEADER_SIZE
)
468 print('unable to get header')
472 if not proxy
.parseHeader(header
):
473 print('unable to parse header')
478 proxyContent
= conn
.recv(proxy
.contentLen
)
480 print('unable to get content')
484 payload
= header
+ proxyContent
485 toQueue
.put(payload
, True, cls
._queueTimeout
)
487 # be careful, HTTP/2 headers and data might be in different recv() results
488 requestHeaders
= None
491 data
= conn
.recv(65535)
492 except Exception as err
:
494 print(f
'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
498 events
= h2conn
.receive_data(data
)
500 if isinstance(event
, h2
.events
.RequestReceived
):
501 requestHeaders
= event
.headers
502 if isinstance(event
, h2
.events
.DataReceived
):
503 h2conn
.acknowledge_received_data(event
.flow_controlled_length
, event
.stream_id
)
504 if not event
.stream_id
in dnsData
:
505 dnsData
[event
.stream_id
] = b
''
506 dnsData
[event
.stream_id
] = dnsData
[event
.stream_id
] + (event
.data
)
507 if event
.stream_ended
:
511 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=ignoreTrailing
)
512 except dns
.message
.TrailingJunk
as e
:
513 if trailingDataResponse
is False or forceRcode
is True:
515 print("DOH query with trailing data, synthesizing response")
516 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=True)
517 forceRcode
= trailingDataResponse
520 status
, wire
= callback(request
, requestHeaders
, fromQueue
, toQueue
)
522 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
524 wire
= response
.to_wire(max_size
=65535)
532 (':status', str(status
)),
533 ('content-length', str(len(wire
))),
534 ('content-type', 'application/dns-message'),
536 h2conn
.send_headers(stream_id
=event
.stream_id
, headers
=headers
)
537 h2conn
.send_data(stream_id
=event
.stream_id
, data
=wire
, end_stream
=True)
539 data_to_send
= h2conn
.data_to_send()
541 conn
.sendall(data_to_send
)
550 def DOHResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, useProxyProtocol
=False):
551 cls
._backgroundThreads
[threading
.get_native_id()] = True
552 # trailingDataResponse=True means "ignore trailing data".
553 # Other values are either False (meaning "raise an exception")
554 # or are interpreted as a response RCODE for queries with trailing data.
555 # callback is invoked for every -even healthcheck ones- query and should return a raw response
557 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
558 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
559 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
561 sock
.bind(("127.0.0.1", port
))
562 except socket
.error
as e
:
563 print("Error binding in the TCP responder: %s" % str(e
))
569 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
571 config
= h2
.config
.H2Configuration(client_side
=False)
575 (conn
, _
) = sock
.accept()
578 except ConnectionResetError
:
580 except socket
.timeout
:
581 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
582 del cls
._backgroundThreads
[threading
.get_native_id()]
588 thread
= threading
.Thread(name
='DoH Connection Handler',
589 target
=cls
.handleDoHConnection
,
590 args
=[config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
])
597 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
598 if useQueue
and response
is not None:
599 cls
._toResponderQueue
.put(response
, True, timeout
)
602 cls
._sock
.settimeout(timeout
)
606 query
= query
.to_wire()
607 cls
._sock
.send(query
)
608 data
= cls
._sock
.recv(4096)
609 except socket
.timeout
:
613 cls
._sock
.settimeout(None)
617 if useQueue
and not cls
._fromResponderQueue
.empty():
618 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
620 message
= dns
.message
.from_wire(data
)
621 return (receivedQuery
, message
)
624 def openTCPConnection(cls
, timeout
=None, port
=None):
625 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
626 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
628 sock
.settimeout(timeout
)
631 port
= cls
._dnsDistPort
633 sock
.connect(("127.0.0.1", port
))
637 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None, alpn
=[]):
638 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
639 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
641 sock
.settimeout(timeout
)
644 if hasattr(ssl
, 'create_default_context'):
645 sslctx
= ssl
.create_default_context(cafile
=caCert
)
646 if len(alpn
)> 0 and hasattr(sslctx
, 'set_alpn_protocols'):
647 sslctx
.set_alpn_protocols(alpn
)
648 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
650 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
652 sslsock
.connect(("127.0.0.1", port
))
656 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
658 wire
= query
.to_wire()
663 cls
._toResponderQueue
.put(response
, True, timeout
)
665 sock
.send(struct
.pack("!H", len(wire
)))
669 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
673 (datalen
,) = struct
.unpack("!H", data
)
675 data
= sock
.recv(datalen
)
678 message
= dns
.message
.from_wire(data
)
681 if useQueue
and not cls
._fromResponderQueue
.empty():
682 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
684 return (receivedQuery
, message
)
690 def sendDOTQuery(cls
, port
, serverName
, query
, response
, caFile
, useQueue
=True):
691 conn
= cls
.openTLSConnection(port
, serverName
, caFile
)
692 cls
.sendTCPQueryOverConnection(conn
, query
, response
=response
)
694 return cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
695 return None, cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
698 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
701 cls
._toResponderQueue
.put(response
, True, timeout
)
704 sock
= cls
.openTCPConnection(timeout
)
705 except socket
.timeout
as e
:
706 print("Timeout while opening TCP connection: %s" % (str(e
)))
710 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
, timeout
=timeout
)
711 message
= cls
.recvTCPResponseOverConnection(sock
, timeout
=timeout
)
712 except socket
.timeout
as e
:
713 print("Timeout while sending or receiving TCP data: %s" % (str(e
)))
714 except socket
.error
as e
:
715 print("Network error: %s" % (str(e
)))
721 if useQueue
and not cls
._fromResponderQueue
.empty():
723 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
725 print("queue is empty")
727 return (receivedQuery
, message
)
730 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
732 for response
in responses
:
733 cls
._toResponderQueue
.put(response
, True, timeout
)
734 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
735 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
737 sock
.settimeout(timeout
)
739 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
744 wire
= query
.to_wire()
748 sock
.send(struct
.pack("!H", len(wire
)))
754 (datalen
,) = struct
.unpack("!H", data
)
755 data
= sock
.recv(datalen
)
756 messages
.append(dns
.message
.from_wire(data
))
758 except socket
.timeout
as e
:
759 print("Timeout while receiving multiple TCP responses: %s" % (str(e
)))
760 except socket
.error
as e
:
761 print("Network error: %s" % (str(e
)))
766 if useQueue
and not cls
._fromResponderQueue
.empty():
767 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
768 return (receivedQuery
, messages
)
771 # This function is called before every test
773 # Clear the responses counters
774 self
._responsesCounter
.clear()
776 self
._healthCheckCounter
= 0
778 # Make sure the queues are empty, in case
779 # a previous test failed
780 self
.clearResponderQueues()
782 super(DNSDistTest
, self
).setUp()
def clearToResponderQueue(cls):
    """Discard every item pending in ``cls._toResponderQueue``.

    The queue is drained with non-blocking ``get`` calls until it reports
    ``queue.Empty``. The previous ``empty()`` check followed by
    ``get(False)`` could raise ``Empty`` when a responder thread (which
    reads from this same queue) took the last item between the two calls.
    """
    from queue import Empty

    pending = cls._toResponderQueue
    while True:
        try:
            pending.get(False)
        except Empty:
            # Nothing left, or another consumer got there first.
            return
def clearFromResponderQueue(cls):
    """Discard every item pending in ``cls._fromResponderQueue``.

    Items are removed one at a time with a non-blocking ``get`` for as
    long as the queue reports that it is not empty.
    """
    received = cls._fromResponderQueue
    while True:
        if received.empty():
            break
        received.get(False)
def clearResponderQueues(cls):
    """Empty both responder queues, the "to" queue first, then the "from" queue."""
    for drain in (cls.clearToResponderQueue, cls.clearFromResponderQueue):
        drain()
def generateConsoleKey():
    """Return a new key produced by ``libnacl.utils.salsa_key()``."""
    consoleKey = libnacl.utils.salsa_key()
    return consoleKey
804 def _encryptConsole(cls
, command
, nonce
):
805 command
= command
.encode('UTF-8')
806 if cls
._consoleKey
is None:
808 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
811 def _decryptConsole(cls
, command
, nonce
):
812 if cls
._consoleKey
is None:
815 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
816 return result
.decode('UTF-8')
819 def sendConsoleCommand(cls
, command
, timeout
=5.0, IPv6
=False):
820 ourNonce
= libnacl
.utils
.rand_nonce()
822 sock
= socket
.socket(socket
.AF_INET
if not IPv6
else socket
.AF_INET6
, socket
.SOCK_STREAM
)
823 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
825 sock
.settimeout(timeout
)
827 sock
.connect(("::1", cls
._consolePort
, 0, 0) if IPv6
else ("127.0.0.1", cls
._consolePort
))
829 theirNonce
= sock
.recv(len(ourNonce
))
830 if len(theirNonce
) != len(ourNonce
):
831 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
832 if len(theirNonce
) == 0:
833 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
836 halfNonceSize
= int(len(ourNonce
) / 2)
837 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
838 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
839 msg
= cls
._encryptConsole
(command
, writingNonce
)
840 sock
.send(struct
.pack("!I", len(msg
)))
844 raise socket
.error("Got EOF while reading the response size")
846 (responseLen
,) = struct
.unpack("!I", data
)
847 data
= sock
.recv(responseLen
)
848 response
= cls
._decryptConsole
(data
, readingNonce
)
def compareOptions(self, a, b):
    """Assert that the option sequences ``a`` and ``b`` match.

    The lengths are compared first, then the elements pairwise in order;
    the first mismatch fails through ``self.assertEqual``.
    """
    self.assertEqual(len(a), len(b))
    for left, right in zip(a, b):
        self.assertEqual(left, right)
def checkMessageNoEDNS(self, expected, received):
    """Assert ``received`` equals ``expected`` and carries no EDNS data.

    "No EDNS" is checked as ``received.edns == -1`` and an empty
    ``received.options``.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, -1)
    optionCount = len(received.options)
    self.assertEqual(optionCount, 0)
def checkMessageEDNSWithoutOptions(self, expected, received):
    """Assert ``received`` equals ``expected`` with EDNS fields in agreement.

    Checks, in order: message equality, ``received.edns == 0``, and that
    both messages have the same ``payload``.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    wantedPayload = expected.payload
    self.assertEqual(wantedPayload, received.payload)
867 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
868 self
.assertEqual(expected
, received
)
869 self
.assertEqual(received
.edns
, 0)
870 self
.assertEqual(expected
.payload
, received
.payload
)
871 self
.assertEqual(len(received
.options
), withCookies
)
873 for option
in received
.options
:
874 self
.assertEqual(option
.otype
, 10)
876 for option
in received
.options
:
877 self
.assertNotEqual(option
.otype
, 10)
879 def checkMessageEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
880 self
.assertEqual(expected
, received
)
881 self
.assertEqual(received
.edns
, 0)
882 self
.assertEqual(expected
.payload
, received
.payload
)
883 self
.assertEqual(len(received
.options
), 1 + additionalOptions
)
885 for option
in received
.options
:
886 if option
.otype
== clientsubnetoption
.ASSIGNED_OPTION_CODE
:
889 self
.assertNotEqual(additionalOptions
, 0)
891 self
.compareOptions(expected
.options
, received
.options
)
892 self
.assertTrue(hasECS
)
def checkMessageEDNS(self, expected, received):
    """Assert ``received`` equals ``expected``, including its EDNS options.

    Checks, in order: message equality, ``received.edns == 0``, equal
    ``payload``, equal number of options, and finally the options
    themselves through ``self.compareOptions``.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    wantedOptions = expected.options
    actualOptions = received.options
    self.assertEqual(len(wantedOptions), len(actualOptions))
    self.compareOptions(expected.options, received.options)
def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
    """Query-side alias: forwards all arguments to ``checkMessageEDNSWithECS``."""
    self.checkMessageEDNSWithECS(
        expected,
        received,
        additionalOptions,
    )
def checkQueryEDNS(self, expected, received):
    """Query-side alias: forwards both messages to ``checkMessageEDNS``."""
    self.checkMessageEDNS(
        expected,
        received,
    )
def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
    """Response-side alias: forwards all arguments to ``checkMessageEDNSWithECS``."""
    self.checkMessageEDNSWithECS(
        expected,
        received,
        additionalOptions,
    )
def checkQueryEDNSWithoutECS(self, expected, received):
    """Query-side alias: forwards both messages to ``checkMessageEDNSWithoutECS``.

    No cookie count is passed, so the callee's own default applies.
    """
    self.checkMessageEDNSWithoutECS(
        expected,
        received,
    )
def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
    """Response-side alias: forwards all arguments to ``checkMessageEDNSWithoutECS``."""
    self.checkMessageEDNSWithoutECS(
        expected,
        received,
        withCookies,
    )
def checkQueryNoEDNS(self, expected, received):
    """Query-side alias: forwards both messages to ``checkMessageNoEDNS``."""
    self.checkMessageNoEDNS(
        expected,
        received,
    )
def checkResponseNoEDNS(self, expected, received):
    """Response-side alias: forwards both messages to ``checkMessageNoEDNS``."""
    self.checkMessageNoEDNS(
        expected,
        received,
    )
923 def generateNewCertificateAndKey(filePrefix
):
924 # generate and sign a new cert
925 cmd
= ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix
+ '.key', '-out', filePrefix
+ '.csr', '-config', 'configServer.conf']
928 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
929 output
= process
.communicate(input='')
930 except subprocess
.CalledProcessError
as exc
:
931 raise AssertionError('openssl req failed (%d): %s' % (exc
.returncode
, exc
.output
))
932 cmd
= ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix
+ '.csr', '-out', filePrefix
+ '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
935 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
936 output
= process
.communicate(input='')
937 except subprocess
.CalledProcessError
as exc
:
938 raise AssertionError('openssl x509 failed (%d): %s' % (exc
.returncode
, exc
.output
))
940 with
open(filePrefix
+ '.chain', 'w') as outFile
:
941 for inFileName
in [filePrefix
+ '.pem', 'ca.pem']:
942 with
open(inFileName
) as inFile
:
943 outFile
.write(inFile
.read())
945 cmd
= ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix
+ '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix
+ '.key', '-out', filePrefix
+ '.p12']
948 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
949 output
= process
.communicate(input='')
950 except subprocess
.CalledProcessError
as exc
:
951 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc
.returncode
, exc
.output
))
953 def checkMessageProxyProtocol(self
, receivedProxyPayload
, source
, destination
, isTCP
, values
=[], v6
=False, sourcePort
=None, destinationPort
=None):
954 proxy
= ProxyProtocol()
955 self
.assertTrue(proxy
.parseHeader(receivedProxyPayload
))
956 self
.assertEqual(proxy
.version
, 0x02)
957 self
.assertEqual(proxy
.command
, 0x01)
959 self
.assertEqual(proxy
.family
, 0x02)
961 self
.assertEqual(proxy
.family
, 0x01)
963 self
.assertEqual(proxy
.protocol
, 0x02)
965 self
.assertEqual(proxy
.protocol
, 0x01)
966 self
.assertGreater(proxy
.contentLen
, 0)
968 self
.assertTrue(proxy
.parseAddressesAndPorts(receivedProxyPayload
))
969 self
.assertEqual(proxy
.source
, source
)
970 self
.assertEqual(proxy
.destination
, destination
)
972 self
.assertEqual(proxy
.sourcePort
, sourcePort
)
974 self
.assertEqual(proxy
.destinationPort
, destinationPort
)
976 self
.assertEqual(proxy
.destinationPort
, self
._dnsDistPort
)
978 self
.assertTrue(proxy
.parseAdditionalValues(receivedProxyPayload
))
981 self
.assertEqual(proxy
.values
, values
)
984 def getDOHGetURL(cls
, baseurl
, query
, rawQuery
=False):
988 wire
= query
.to_wire()
989 param
= base64
.urlsafe_b64encode(wire
).decode('UTF8').rstrip('=')
990 return baseurl
+ "?dns=" + param
993 def openDOHConnection(cls
, port
, caFile
, timeout
=2.0):
995 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2
)
997 conn
.setopt(pycurl
.HTTPHEADER
, ["Content-type: application/dns-message",
998 "Accept: application/dns-message"])
1002 def sendDOHQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True, fromQueue
=None, toQueue
=None, conn
=None):
1003 url
= cls
.getDOHGetURL(baseurl
, query
, rawQuery
)
1006 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
1007 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1008 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE
)
1011 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1012 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1014 conn
.setopt(pycurl
.CAINFO
, caFile
)
1016 response_headers
= BytesIO()
1017 #conn.setopt(pycurl.VERBOSE, True)
1018 conn
.setopt(pycurl
.URL
, url
)
1019 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1021 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1022 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1026 toQueue
.put(response
, True, timeout
)
1028 cls
._toResponderQueue
.put(response
, True, timeout
)
1030 receivedQuery
= None
1032 cls
._response
_headers
= ''
1033 data
= conn
.perform_rb()
1034 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1035 if cls
._rcode
== 200 and not rawResponse
:
1036 message
= dns
.message
.from_wire(data
)
1042 if not fromQueue
.empty():
1043 receivedQuery
= fromQueue
.get(True, timeout
)
1045 if not cls
._fromResponderQueue
.empty():
1046 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1048 cls
._response
_headers
= response_headers
.getvalue()
1049 return (receivedQuery
, message
)
1052 def sendDOHPostQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True):
1054 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
1055 response_headers
= BytesIO()
1056 #conn.setopt(pycurl.VERBOSE, True)
1057 conn
.setopt(pycurl
.URL
, url
)
1058 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1059 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1060 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE
)
1062 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1063 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1065 conn
.setopt(pycurl
.CAINFO
, caFile
)
1067 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1068 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1069 conn
.setopt(pycurl
.POST
, True)
1072 data
= data
.to_wire()
1074 conn
.setopt(pycurl
.POSTFIELDS
, data
)
1077 cls
._toResponderQueue
.put(response
, True, timeout
)
1079 receivedQuery
= None
1081 cls
._response
_headers
= ''
1082 data
= conn
.perform_rb()
1083 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1084 if cls
._rcode
== 200 and not rawResponse
:
1085 message
= dns
.message
.from_wire(data
)
1089 if useQueue
and not cls
._fromResponderQueue
.empty():
1090 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1092 cls
._response
_headers
= response_headers
.getvalue()
1093 return (receivedQuery
, message
)
def sendDOHQueryWrapper(self, query, response, useQueue=True):
    """Call ``sendDOHQuery`` with this instance's DoH settings.

    The port, server name, base URL and CA file come from
    ``_dohServerPort``, ``_serverName``, ``_dohBaseURL`` and ``_caCert``.
    Returns whatever ``sendDOHQuery`` returns.
    """
    return self.sendDOHQuery(
        self._dohServerPort,
        self._serverName,
        self._dohBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
    """Call ``sendDOHQuery`` with this instance's nghttp2 DoH settings.

    The port and base URL come from ``_dohWithNGHTTP2ServerPort`` and
    ``_dohWithNGHTTP2BaseURL``; the server name and CA file from
    ``_serverName`` and ``_caCert``. Returns whatever ``sendDOHQuery``
    returns.
    """
    return self.sendDOHQuery(
        self._dohWithNGHTTP2ServerPort,
        self._serverName,
        self._dohWithNGHTTP2BaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
    """Call ``sendDOHQuery`` with this instance's h2o DoH settings.

    The port and base URL come from ``_dohWithH2OServerPort`` and
    ``_dohWithH2OBaseURL``; the server name and CA file from
    ``_serverName`` and ``_caCert``. Returns whatever ``sendDOHQuery``
    returns.
    """
    return self.sendDOHQuery(
        self._dohWithH2OServerPort,
        self._serverName,
        self._dohWithH2OBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
def sendDOTQueryWrapper(self, query, response, useQueue=True):
    """Call ``sendDOTQuery`` with this instance's TLS settings.

    The port, server name and CA file come from ``_tlsServerPort``,
    ``_serverName`` and ``_caCert``. The response and CA file are passed
    positionally. Returns whatever ``sendDOTQuery`` returns.
    """
    return self.sendDOTQuery(
        self._tlsServerPort,
        self._serverName,
        query,
        response,
        self._caCert,
        useQueue=useQueue,
    )
def sendDOQQueryWrapper(self, query, response, useQueue=True):
    """Call ``sendDOQQuery`` with this instance's DoQ settings.

    The port, CA file and server name come from ``_doqServerPort``,
    ``_caCert`` and ``_serverName``. Returns whatever ``sendDOQQuery``
    returns.
    """
    return self.sendDOQQuery(
        self._doqServerPort,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
        serverName=self._serverName,
    )
def sendDOH3QueryWrapper(self, query, response, useQueue=True):
    """Call ``sendDOH3Query`` with this instance's DoH3 settings.

    The port, base URL, CA file and server name come from
    ``_doh3ServerPort``, ``_dohBaseURL``, ``_caCert`` and ``_serverName``.
    Returns whatever ``sendDOH3Query`` returns.
    """
    return self.sendDOH3Query(
        self._doh3ServerPort,
        self._dohBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
        serverName=self._serverName,
    )
1113 def getDOQConnection(cls
, port
, caFile
=None, source
=None, source_port
=0):
1115 manager
= dns
.quic
.SyncQuicManager(
1119 return manager
.connect('127.0.0.1', port
, source
, source_port
)
1122 def sendDOQQuery(cls
, port
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, fromQueue
=None, toQueue
=None, connection
=None, serverName
=None):
1126 toQueue
.put(response
, True, timeout
)
1128 cls
._toResponderQueue
.put(response
, True, timeout
)
1130 (message
, _
) = quic_query(query
, '127.0.0.1', timeout
, port
, verify
=caFile
, server_hostname
=serverName
)
1132 receivedQuery
= None
1136 if not fromQueue
.empty():
1137 receivedQuery
= fromQueue
.get(True, timeout
)
1139 if not cls
._fromResponderQueue
.empty():
1140 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1142 return (receivedQuery
, message
)
1145 def sendDOH3Query(cls
, port
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, fromQueue
=None, toQueue
=None, connection
=None, serverName
=None, post
=False):
1149 toQueue
.put(response
, True, timeout
)
1151 cls
._toResponderQueue
.put(response
, True, timeout
)
1153 message
= doh3_query(query
, baseurl
, timeout
, port
, verify
=caFile
, server_hostname
=serverName
, post
=post
)
1155 receivedQuery
= None
1159 if not fromQueue
.empty():
1160 receivedQuery
= fromQueue
.get(True, timeout
)
1162 if not cls
._fromResponderQueue
.empty():
1163 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1165 return (receivedQuery
, message
)