16 import clientsubnetoption
29 from io
import BytesIO
31 from doqclient
import quic_query
32 from doh3client
import doh3_query
34 from eqdnsmessage
import AssertEqualDNSMessageMixin
35 from proxyprotocol
import ProxyProtocol
37 # Python2/3 compatibility hacks
39 from queue
import Queue
41 from Queue
import Queue
49 if not 'PYTEST_XDIST_WORKER' in os
.environ
:
51 workerName
= os
.environ
['PYTEST_XDIST_WORKER']
52 return int(workerName
[2:])
56 def pickAvailablePort():
58 workerID
= getWorkerID()
59 if workerID
in workerPorts
:
60 port
= workerPorts
[workerID
] + 1
62 port
= 11000 + (workerID
* 1000)
63 workerPorts
[workerID
] = port
66 class DNSDistTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
68 Set up a dnsdist instance and responder threads.
69 Queries sent to dnsdist are relayed to the responder threads,
70 who reply with the response provided by the tests themselves
71 on a queue. Responder threads also queue the queries received
72 from dnsdist on a separate queue, allowing the tests to check
73 that the queries sent from dnsdist were as expected.
75 _dnsDistListeningAddr
= "127.0.0.1"
76 _toResponderQueue
= Queue()
77 _fromResponderQueue
= Queue()
80 _responsesCounter
= {}
81 _config_template
= """
83 _config_params
= ['_testServerPort']
84 _acl
= ['127.0.0.1/32']
86 _healthCheckName
= 'a.root-servers.net.'
87 _healthCheckCounter
= 0
88 _answerUnexpected
= True
89 _checkConfigExpectedOutput
= None
92 _skipListeningOnCL
= False
93 _alternateListeningAddr
= None
94 _alternateListeningPort
= None
95 _backgroundThreads
= {}
98 _extraStartupSleep
= 0
99 _dnsDistPort
= pickAvailablePort()
100 _consolePort
= pickAvailablePort()
101 _testServerPort
= pickAvailablePort()
104 def waitForTCPSocket(cls
, ipaddress
, port
):
105 for try_number
in range(0, 20):
107 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
109 sock
.connect((ipaddress
, port
))
112 except Exception as err
:
113 if err
.errno
!= errno
.ECONNREFUSED
:
114 print(f
'Error occurred: {try_number} {err}', file=sys
.stderr
)
116 # We assume the dnsdist instance does not listen. That's fine.
def startResponders(cls):
    """Launch the UDP and TCP responder threads used as dnsdist's backend.

    A new port is picked with pickAvailablePort() and stored in
    ``cls._testServerPort``. Both responders receive that port plus the shared
    ``cls._toResponderQueue`` / ``cls._fromResponderQueue`` pair and run as
    daemon threads. Finally ``cls.waitForTCPSocket`` is called for that port on
    127.0.0.1.
    """
    print("Launching responders..")
    cls._testServerPort = pickAvailablePort()

    # UDP responder: started first, as a daemon so it never blocks interpreter exit.
    cls._UDPResponder = threading.Thread(
        name='UDP Responder',
        target=cls.UDPResponder,
        args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue],
        daemon=True,
    )
    cls._UDPResponder.start()

    # TCP responder on the same port number, sharing the same queues.
    cls._TCPResponder = threading.Thread(
        name='TCP Responder',
        target=cls.TCPResponder,
        args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue],
        daemon=True,
    )
    cls._TCPResponder.start()

    cls.waitForTCPSocket("127.0.0.1", cls._testServerPort)
132 def startDNSDist(cls
):
133 cls
._dnsDistPort
= pickAvailablePort()
134 cls
._consolePort
= pickAvailablePort()
136 print("Launching dnsdist..")
137 confFile
= os
.path
.join('configs', 'dnsdist_%s.conf' % (cls
.__name
__))
138 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
140 with
open(confFile
, 'w') as conf
:
141 conf
.write("-- Autogenerated by dnsdisttests.py\n")
142 conf
.write(f
"-- dnsdist will listen on {cls._dnsDistPort}")
143 conf
.write(cls
._config
_template
% params
)
144 conf
.write("setSecurityPollSuffix('')")
146 if cls
._skipListeningOnCL
:
147 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
]
149 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
,
150 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
153 dnsdistcmd
.append('-v')
155 if 'LD_LIBRARY_PATH' in os
.environ
:
156 dnsdistcmd
.insert(0, 'LD_LIBRARY_PATH=' + os
.environ
['LD_LIBRARY_PATH'])
157 dnsdistcmd
.insert(0, 'sudo')
160 dnsdistcmd
.extend(['--acl', acl
])
161 print(' '.join(dnsdistcmd
))
163 # validate config with --check-config, which sets client=true, possibly exposing bugs.
164 testcmd
= dnsdistcmd
+ ['--check-config']
166 output
= subprocess
.check_output(testcmd
, stderr
=subprocess
.STDOUT
, close_fds
=True)
167 except subprocess
.CalledProcessError
as exc
:
168 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc
.returncode
, exc
.output
))
169 if cls
._checkConfigExpectedOutput
is not None:
170 expectedOutput
= cls
._checkConfigExpectedOutput
172 expectedOutput
= ('Configuration \'%s\' OK!\n' % (confFile
)).encode()
173 if not cls
._verboseMode
and output
!= expectedOutput
:
174 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output
, expectedOutput
))
176 logFile
= os
.path
.join('configs', 'dnsdist_%s.log' % (cls
.__name
__))
177 with
open(logFile
, 'w') as fdLog
:
178 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdLog
, stderr
=fdLog
)
180 if cls
._alternateListeningAddr
and cls
._alternateListeningPort
:
181 cls
.waitForTCPSocket(cls
._alternateListeningAddr
, cls
._alternateListeningPort
)
183 cls
.waitForTCPSocket(cls
._dnsDistListeningAddr
, cls
._dnsDistPort
)
185 if cls
._dnsdist
.poll() is not None:
186 print(f
"\n*** startDNSDist log for {logFile} ***")
187 with
open(logFile
, 'r') as fdLog
:
189 print(f
"*** End startDNSDist log for {logFile} ***")
190 raise AssertionError('%s failed (%d)' % (dnsdistcmd
, cls
._dnsdist
.returncode
))
191 time
.sleep(cls
._extraStartupSleep
)
def setUpSockets(cls):
    """Create the UDP socket the tests use to talk to dnsdist.

    The socket is stored in ``cls._sock``, given a 2-second timeout and
    connected to 127.0.0.1 on ``cls._dnsDistPort``.
    """
    print("Setting up UDP socket..")
    # Publish the socket on the class before configuring it, so it is reachable
    # via cls._sock even if settimeout()/connect() raise.
    cls._sock = udpSocket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udpSocket.settimeout(2.0)
    udpSocket.connect(("127.0.0.1", cls._dnsDistPort))
201 def killProcess(cls
, p
):
202 # Don't try to kill it if it's already dead
203 if p
.poll() is not None:
207 for count
in range(20):
213 print("kill...", p
, file=sys
.stderr
)
217 # There is a race-condition with the poll() and
218 # kill() statements, when the process is dead on the
219 # kill(), this is fine
220 if e
.errno
!= errno
.ESRCH
:
226 cls
.startResponders()
230 print("Launching tests..")
233 def tearDownClass(cls
):
235 # tell the background threads to stop, if any
236 for backgroundThread
in cls
._backgroundThreads
:
237 cls
._backgroundThreads
[backgroundThread
] = False
238 cls
.killProcess(cls
._dnsdist
)
def _ResponderIncrementCounter(cls):
    """Count one more response produced by the calling responder thread.

    The tally lives in ``cls._responsesCounter`` and is keyed by the name of the
    thread doing the answering (e.g. 'UDP Responder', 'TCP Responder'). An
    unseen thread name starts at 1, and each later call adds 1 to it.
    """
    threadName = threading.current_thread().name
    # get(..., 0) + 1 covers both "first response" and "another response" in one
    # step, so the counter can never be reset by an unconditional assignment.
    cls._responsesCounter[threadName] = cls._responsesCounter.get(threadName, 0) + 1
248 def _getResponse(cls
, request
, fromQueue
, toQueue
, synthesize
=None):
250 if len(request
.question
) != 1:
251 print("Skipping query with question count %d" % (len(request
.question
)))
253 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
255 cls
._healthCheckCounter
+= 1
256 response
= dns
.message
.make_response(request
)
258 cls
._ResponderIncrementCounter
()
259 if not fromQueue
.empty():
260 toQueue
.put(request
, True, cls
._queueTimeout
)
261 response
= fromQueue
.get(True, cls
._queueTimeout
)
263 response
= copy
.copy(response
)
264 response
.id = request
.id
266 if synthesize
is not None:
267 response
= dns
.message
.make_response(request
)
268 response
.set_rcode(synthesize
)
271 if cls
._answerUnexpected
:
272 response
= dns
.message
.make_response(request
)
273 response
.set_rcode(dns
.rcode
.SERVFAIL
)
278 def UDPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, callback
=None):
279 cls
._backgroundThreads
[threading
.get_native_id()] = True
280 # trailingDataResponse=True means "ignore trailing data".
281 # Other values are either False (meaning "raise an exception")
282 # or are interpreted as a response RCODE for queries with trailing data.
283 # callback is invoked for every -even healthcheck ones- query and should return a raw response
284 ignoreTrailing
= trailingDataResponse
is True
286 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
287 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
288 sock
.bind(("127.0.0.1", port
))
292 data
, addr
= sock
.recvfrom(4096)
293 except socket
.timeout
:
294 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
295 del cls
._backgroundThreads
[threading
.get_native_id()]
302 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
303 except dns
.message
.TrailingJunk
as e
:
304 print('trailing data exception in UDPResponder')
305 if trailingDataResponse
is False or forceRcode
is True:
307 print("UDP query with trailing data, synthesizing response")
308 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
309 forceRcode
= trailingDataResponse
313 wire
= callback(request
)
316 forceRcode
= dns
.rcode
.BADVERS
317 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
319 wire
= response
.to_wire()
324 sock
.sendto(wire
, addr
)
329 def handleTCPConnection(cls
, conn
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, partialWrite
=False):
330 ignoreTrailing
= trailingDataResponse
is True
333 except Exception as err
:
335 print(f
'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
340 (datalen
,) = struct
.unpack("!H", data
)
341 data
= conn
.recv(datalen
)
344 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
345 except dns
.message
.TrailingJunk
as e
:
346 if trailingDataResponse
is False or forceRcode
is True:
348 print("TCP query with trailing data, synthesizing response")
349 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
350 forceRcode
= trailingDataResponse
353 wire
= callback(request
)
356 forceRcode
= dns
.rcode
.BADVERS
357 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
359 wire
= response
.to_wire(max_size
=65535)
365 wireLen
= struct
.pack("!H", len(wire
))
368 conn
.send(bytes([b
]))
374 while multipleResponses
:
375 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
376 # otherwise we might read responses intended for the next connection
377 if fromQueue
.empty():
380 response
= fromQueue
.get(False)
384 response
= copy
.copy(response
)
385 response
.id = request
.id
386 wire
= response
.to_wire(max_size
=65535)
388 conn
.send(struct
.pack("!H", len(wire
)))
390 except socket
.error
as e
:
391 # some of the tests are going to close
392 # the connection on us, just deal with it
398 def TCPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, multipleConnections
=False, listeningAddr
='127.0.0.1', partialWrite
=False):
399 cls
._backgroundThreads
[threading
.get_native_id()] = True
400 # trailingDataResponse=True means "ignore trailing data".
401 # Other values are either False (meaning "raise an exception")
402 # or are interpreted as a response RCODE for queries with trailing data.
403 # callback is invoked for every -even healthcheck ones- query and should return a raw response
405 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
406 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
407 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
409 sock
.bind((listeningAddr
, port
))
410 except socket
.error
as e
:
411 print("Error binding in the TCP responder: %s" % str(e
))
417 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
421 (conn
, _
) = sock
.accept()
424 except ConnectionResetError
:
426 except socket
.timeout
:
427 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
428 del cls
._backgroundThreads
[threading
.get_native_id()]
434 if multipleConnections
:
435 thread
= threading
.Thread(name
='TCP Connection Handler',
436 target
=cls
.handleTCPConnection
,
437 args
=[conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, partialWrite
])
441 cls
.handleTCPConnection(conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, partialWrite
)
446 def handleDoHConnection(cls
, config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
):
447 ignoreTrailing
= trailingDataResponse
is True
449 h2conn
= h2
.connection
.H2Connection(config
=config
)
450 h2conn
.initiate_connection()
451 conn
.sendall(h2conn
.data_to_send())
452 except ssl
.SSLEOFError
as e
:
453 print("Unexpected EOF: %s" % (e
))
455 except Exception as err
:
456 print(f
'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
462 # try to read the entire Proxy Protocol header
463 proxy
= ProxyProtocol()
464 header
= conn
.recv(proxy
.HEADER_SIZE
)
466 print('unable to get header')
470 if not proxy
.parseHeader(header
):
471 print('unable to parse header')
476 proxyContent
= conn
.recv(proxy
.contentLen
)
478 print('unable to get content')
482 payload
= header
+ proxyContent
483 toQueue
.put(payload
, True, cls
._queueTimeout
)
485 # be careful, HTTP/2 headers and data might be in different recv() results
486 requestHeaders
= None
489 data
= conn
.recv(65535)
490 except Exception as err
:
492 print(f
'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
496 events
= h2conn
.receive_data(data
)
498 if isinstance(event
, h2
.events
.RequestReceived
):
499 requestHeaders
= event
.headers
500 if isinstance(event
, h2
.events
.DataReceived
):
501 h2conn
.acknowledge_received_data(event
.flow_controlled_length
, event
.stream_id
)
502 if not event
.stream_id
in dnsData
:
503 dnsData
[event
.stream_id
] = b
''
504 dnsData
[event
.stream_id
] = dnsData
[event
.stream_id
] + (event
.data
)
505 if event
.stream_ended
:
509 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=ignoreTrailing
)
510 except dns
.message
.TrailingJunk
as e
:
511 if trailingDataResponse
is False or forceRcode
is True:
513 print("DOH query with trailing data, synthesizing response")
514 request
= dns
.message
.from_wire(dnsData
[event
.stream_id
], ignore_trailing
=True)
515 forceRcode
= trailingDataResponse
518 status
, wire
= callback(request
, requestHeaders
, fromQueue
, toQueue
)
520 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
522 wire
= response
.to_wire(max_size
=65535)
530 (':status', str(status
)),
531 ('content-length', str(len(wire
))),
532 ('content-type', 'application/dns-message'),
534 h2conn
.send_headers(stream_id
=event
.stream_id
, headers
=headers
)
535 h2conn
.send_data(stream_id
=event
.stream_id
, data
=wire
, end_stream
=True)
537 data_to_send
= h2conn
.data_to_send()
539 conn
.sendall(data_to_send
)
548 def DOHResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None, tlsContext
=None, useProxyProtocol
=False):
549 cls
._backgroundThreads
[threading
.get_native_id()] = True
550 # trailingDataResponse=True means "ignore trailing data".
551 # Other values are either False (meaning "raise an exception")
552 # or are interpreted as a response RCODE for queries with trailing data.
553 # callback is invoked for every -even healthcheck ones- query and should return a raw response
555 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
556 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
557 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
559 sock
.bind(("127.0.0.1", port
))
560 except socket
.error
as e
:
561 print("Error binding in the TCP responder: %s" % str(e
))
567 sock
= tlsContext
.wrap_socket(sock
, server_side
=True)
569 config
= h2
.config
.H2Configuration(client_side
=False)
573 (conn
, _
) = sock
.accept()
576 except ConnectionResetError
:
578 except socket
.timeout
:
579 if cls
._backgroundThreads
.get(threading
.get_native_id(), False) == False:
580 del cls
._backgroundThreads
[threading
.get_native_id()]
586 thread
= threading
.Thread(name
='DoH Connection Handler',
587 target
=cls
.handleDoHConnection
,
588 args
=[config
, conn
, fromQueue
, toQueue
, trailingDataResponse
, multipleResponses
, callback
, tlsContext
, useProxyProtocol
])
595 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
596 if useQueue
and response
is not None:
597 cls
._toResponderQueue
.put(response
, True, timeout
)
600 cls
._sock
.settimeout(timeout
)
604 query
= query
.to_wire()
605 cls
._sock
.send(query
)
606 data
= cls
._sock
.recv(4096)
607 except socket
.timeout
:
611 cls
._sock
.settimeout(None)
615 if useQueue
and not cls
._fromResponderQueue
.empty():
616 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
618 message
= dns
.message
.from_wire(data
)
619 return (receivedQuery
, message
)
622 def openTCPConnection(cls
, timeout
=None, port
=None):
623 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
624 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
626 sock
.settimeout(timeout
)
629 port
= cls
._dnsDistPort
631 sock
.connect(("127.0.0.1", port
))
635 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None, alpn
=[]):
636 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
637 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
639 sock
.settimeout(timeout
)
642 if hasattr(ssl
, 'create_default_context'):
643 sslctx
= ssl
.create_default_context(cafile
=caCert
)
644 if len(alpn
)> 0 and hasattr(sslctx
, 'set_alpn_protocols'):
645 sslctx
.set_alpn_protocols(alpn
)
646 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
648 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
650 sslsock
.connect(("127.0.0.1", port
))
654 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
656 wire
= query
.to_wire()
661 cls
._toResponderQueue
.put(response
, True, timeout
)
663 sock
.send(struct
.pack("!H", len(wire
)))
667 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
671 (datalen
,) = struct
.unpack("!H", data
)
673 data
= sock
.recv(datalen
)
676 message
= dns
.message
.from_wire(data
)
679 if useQueue
and not cls
._fromResponderQueue
.empty():
680 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
682 return (receivedQuery
, message
)
688 def sendDOTQuery(cls
, port
, serverName
, query
, response
, caFile
, useQueue
=True):
689 conn
= cls
.openTLSConnection(port
, serverName
, caFile
)
690 cls
.sendTCPQueryOverConnection(conn
, query
, response
=response
)
692 return cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
693 return None, cls
.recvTCPResponseOverConnection(conn
, useQueue
=useQueue
)
696 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
699 cls
._toResponderQueue
.put(response
, True, timeout
)
702 sock
= cls
.openTCPConnection(timeout
)
703 except socket
.timeout
as e
:
704 print("Timeout while opening TCP connection: %s" % (str(e
)))
708 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
, timeout
=timeout
)
709 message
= cls
.recvTCPResponseOverConnection(sock
, timeout
=timeout
)
710 except socket
.timeout
as e
:
711 print("Timeout while sending or receiving TCP data: %s" % (str(e
)))
712 except socket
.error
as e
:
713 print("Network error: %s" % (str(e
)))
719 if useQueue
and not cls
._fromResponderQueue
.empty():
721 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
723 print("queue is empty")
725 return (receivedQuery
, message
)
728 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
730 for response
in responses
:
731 cls
._toResponderQueue
.put(response
, True, timeout
)
732 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
733 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
735 sock
.settimeout(timeout
)
737 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
742 wire
= query
.to_wire()
746 sock
.send(struct
.pack("!H", len(wire
)))
752 (datalen
,) = struct
.unpack("!H", data
)
753 data
= sock
.recv(datalen
)
754 messages
.append(dns
.message
.from_wire(data
))
756 except socket
.timeout
as e
:
757 print("Timeout while receiving multiple TCP responses: %s" % (str(e
)))
758 except socket
.error
as e
:
759 print("Network error: %s" % (str(e
)))
764 if useQueue
and not cls
._fromResponderQueue
.empty():
765 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
766 return (receivedQuery
, messages
)
769 # This function is called before every test
771 # Clear the responses counters
772 self
._responsesCounter
.clear()
774 self
._healthCheckCounter
= 0
776 # Make sure the queues are empty, in case
777 # a previous test failed
778 self
.clearResponderQueues()
780 super(DNSDistTest
, self
).setUp()
def clearToResponderQueue(cls):
    """Discard every item still waiting in ``cls._toResponderQueue``.

    This is the queue the tests fill with responses for the responder threads.
    Items are removed with non-blocking get() calls until empty() reports True.
    """
    pending = cls._toResponderQueue
    while True:
        if pending.empty():
            break
        pending.get(False)
def clearFromResponderQueue(cls):
    """Discard every item still waiting in ``cls._fromResponderQueue``.

    This is the queue on which responder threads report the queries they
    received. Items are removed with non-blocking get() calls until empty()
    reports True.
    """
    received = cls._fromResponderQueue
    while True:
        if received.empty():
            break
        received.get(False)
def clearResponderQueues(cls):
    """Empty both responder queues: first the to-responder one, then the from-responder one."""
    for drain in (cls.clearToResponderQueue, cls.clearFromResponderQueue):
        drain()
def generateConsoleKey():
    """Return a fresh key produced by ``libnacl.utils.salsa_key()``.

    Keys of this kind are what the console helpers pass as the key to
    ``libnacl.crypto_secretbox`` / ``crypto_secretbox_open``.
    """
    freshKey = libnacl.utils.salsa_key()
    return freshKey
802 def _encryptConsole(cls
, command
, nonce
):
803 command
= command
.encode('UTF-8')
804 if cls
._consoleKey
is None:
806 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
809 def _decryptConsole(cls
, command
, nonce
):
810 if cls
._consoleKey
is None:
813 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
814 return result
.decode('UTF-8')
817 def sendConsoleCommand(cls
, command
, timeout
=5.0):
818 ourNonce
= libnacl
.utils
.rand_nonce()
820 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
821 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
823 sock
.settimeout(timeout
)
825 sock
.connect(("127.0.0.1", cls
._consolePort
))
827 theirNonce
= sock
.recv(len(ourNonce
))
828 if len(theirNonce
) != len(ourNonce
):
829 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
830 if len(theirNonce
) == 0:
831 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
834 halfNonceSize
= int(len(ourNonce
) / 2)
835 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
836 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
837 msg
= cls
._encryptConsole
(command
, writingNonce
)
838 sock
.send(struct
.pack("!I", len(msg
)))
842 raise socket
.error("Got EOF while reading the response size")
844 (responseLen
,) = struct
.unpack("!I", data
)
845 data
= sock
.recv(responseLen
)
846 response
= cls
._decryptConsole
(data
, readingNonce
)
def compareOptions(self, a, b):
    """Assert that option sequences *a* and *b* match element by element.

    The lengths are compared first. Once they are known to be equal, each pair
    of items at the same position is checked with assertEqual.
    """
    self.assertEqual(len(a), len(b))
    # The lengths were just asserted equal, so zip() visits every position.
    for mine, theirs in zip(a, b):
        self.assertEqual(mine, theirs)
def checkMessageNoEDNS(self, expected, received):
    """Assert *received* equals *expected* and carries no EDNS.

    "No EDNS" is checked as ``received.edns == -1``, which is dnspython's
    value for a message without an OPT record, plus an empty ``options`` list.
    """
    self.assertEqual(expected, received)
    noEDNS = -1
    self.assertEqual(received.edns, noEDNS)
    self.assertEqual(len(received.options), 0)
def checkMessageEDNSWithoutOptions(self, expected, received):
    """Assert *received* equals *expected*, uses EDNS version 0 and has the same payload size.

    Despite the name, the number of EDNS options is not checked here.
    """
    self.assertEqual(expected, received)
    ednsVersion = received.edns
    self.assertEqual(ednsVersion, 0)
    self.assertEqual(expected.payload, received.payload)
865 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
866 self
.assertEqual(expected
, received
)
867 self
.assertEqual(received
.edns
, 0)
868 self
.assertEqual(expected
.payload
, received
.payload
)
869 self
.assertEqual(len(received
.options
), withCookies
)
871 for option
in received
.options
:
872 self
.assertEqual(option
.otype
, 10)
874 for option
in received
.options
:
875 self
.assertNotEqual(option
.otype
, 10)
877 def checkMessageEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
878 self
.assertEqual(expected
, received
)
879 self
.assertEqual(received
.edns
, 0)
880 self
.assertEqual(expected
.payload
, received
.payload
)
881 self
.assertEqual(len(received
.options
), 1 + additionalOptions
)
883 for option
in received
.options
:
884 if option
.otype
== clientsubnetoption
.ASSIGNED_OPTION_CODE
:
887 self
.assertNotEqual(additionalOptions
, 0)
889 self
.compareOptions(expected
.options
, received
.options
)
890 self
.assertTrue(hasECS
)
def checkMessageEDNS(self, expected, received):
    """Assert *received* matches *expected* including its EDNS options.

    The checks are, in order:
    1. the messages are equal;
    2. *received* uses EDNS version 0;
    3. the payload sizes are equal;
    4. the option lists have the same length and equal elements, via compareOptions.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    wantedOptions, gotOptions = expected.options, received.options
    self.assertEqual(len(wantedOptions), len(gotOptions))
    self.compareOptions(wantedOptions, gotOptions)
def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
    """Query-side alias of checkMessageEDNSWithECS; all arguments are forwarded unchanged."""
    check = self.checkMessageEDNSWithECS
    check(expected, received, additionalOptions)
def checkQueryEDNS(self, expected, received):
    """Query-side alias of checkMessageEDNS; arguments are forwarded unchanged."""
    check = self.checkMessageEDNS
    check(expected, received)
def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
    """Response-side alias of checkMessageEDNSWithECS; all arguments are forwarded unchanged."""
    check = self.checkMessageEDNSWithECS
    check(expected, received, additionalOptions)
def checkQueryEDNSWithoutECS(self, expected, received):
    """Query-side alias of checkMessageEDNSWithoutECS, using its default withCookies value."""
    check = self.checkMessageEDNSWithoutECS
    check(expected, received)
def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
    """Response-side alias of checkMessageEDNSWithoutECS; withCookies is forwarded as given."""
    check = self.checkMessageEDNSWithoutECS
    check(expected, received, withCookies)
def checkQueryNoEDNS(self, expected, received):
    """Query-side alias of checkMessageNoEDNS; arguments are forwarded unchanged."""
    check = self.checkMessageNoEDNS
    check(expected, received)
def checkResponseNoEDNS(self, expected, received):
    """Response-side alias of checkMessageNoEDNS; arguments are forwarded unchanged."""
    check = self.checkMessageNoEDNS
    check(expected, received)
921 def generateNewCertificateAndKey(filePrefix
):
922 # generate and sign a new cert
923 cmd
= ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix
+ '.key', '-out', filePrefix
+ '.csr', '-config', 'configServer.conf']
926 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
927 output
= process
.communicate(input='')
928 except subprocess
.CalledProcessError
as exc
:
929 raise AssertionError('openssl req failed (%d): %s' % (exc
.returncode
, exc
.output
))
930 cmd
= ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix
+ '.csr', '-out', filePrefix
+ '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
933 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
934 output
= process
.communicate(input='')
935 except subprocess
.CalledProcessError
as exc
:
936 raise AssertionError('openssl x509 failed (%d): %s' % (exc
.returncode
, exc
.output
))
938 with
open(filePrefix
+ '.chain', 'w') as outFile
:
939 for inFileName
in [filePrefix
+ '.pem', 'ca.pem']:
940 with
open(inFileName
) as inFile
:
941 outFile
.write(inFile
.read())
943 cmd
= ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix
+ '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix
+ '.key', '-out', filePrefix
+ '.p12']
946 process
= subprocess
.Popen(cmd
, stdout
=subprocess
.PIPE
, stdin
=subprocess
.PIPE
, stderr
=subprocess
.STDOUT
, close_fds
=True)
947 output
= process
.communicate(input='')
948 except subprocess
.CalledProcessError
as exc
:
949 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc
.returncode
, exc
.output
))
951 def checkMessageProxyProtocol(self
, receivedProxyPayload
, source
, destination
, isTCP
, values
=[], v6
=False, sourcePort
=None, destinationPort
=None):
952 proxy
= ProxyProtocol()
953 self
.assertTrue(proxy
.parseHeader(receivedProxyPayload
))
954 self
.assertEqual(proxy
.version
, 0x02)
955 self
.assertEqual(proxy
.command
, 0x01)
957 self
.assertEqual(proxy
.family
, 0x02)
959 self
.assertEqual(proxy
.family
, 0x01)
961 self
.assertEqual(proxy
.protocol
, 0x02)
963 self
.assertEqual(proxy
.protocol
, 0x01)
964 self
.assertGreater(proxy
.contentLen
, 0)
966 self
.assertTrue(proxy
.parseAddressesAndPorts(receivedProxyPayload
))
967 self
.assertEqual(proxy
.source
, source
)
968 self
.assertEqual(proxy
.destination
, destination
)
970 self
.assertEqual(proxy
.sourcePort
, sourcePort
)
972 self
.assertEqual(proxy
.destinationPort
, destinationPort
)
974 self
.assertEqual(proxy
.destinationPort
, self
._dnsDistPort
)
976 self
.assertTrue(proxy
.parseAdditionalValues(receivedProxyPayload
))
979 self
.assertEqual(proxy
.values
, values
)
982 def getDOHGetURL(cls
, baseurl
, query
, rawQuery
=False):
986 wire
= query
.to_wire()
987 param
= base64
.urlsafe_b64encode(wire
).decode('UTF8').rstrip('=')
988 return baseurl
+ "?dns=" + param
991 def openDOHConnection(cls
, port
, caFile
, timeout
=2.0):
993 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2
)
995 conn
.setopt(pycurl
.HTTPHEADER
, ["Content-type: application/dns-message",
996 "Accept: application/dns-message"])
1000 def sendDOHQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True, fromQueue
=None, toQueue
=None, conn
=None):
1001 url
= cls
.getDOHGetURL(baseurl
, query
, rawQuery
)
1004 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
1005 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1006 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE
)
1009 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1010 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1012 conn
.setopt(pycurl
.CAINFO
, caFile
)
1014 response_headers
= BytesIO()
1015 #conn.setopt(pycurl.VERBOSE, True)
1016 conn
.setopt(pycurl
.URL
, url
)
1017 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1019 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1020 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1024 toQueue
.put(response
, True, timeout
)
1026 cls
._toResponderQueue
.put(response
, True, timeout
)
1028 receivedQuery
= None
1030 cls
._response
_headers
= ''
1031 data
= conn
.perform_rb()
1032 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1033 if cls
._rcode
== 200 and not rawResponse
:
1034 message
= dns
.message
.from_wire(data
)
1040 if not fromQueue
.empty():
1041 receivedQuery
= fromQueue
.get(True, timeout
)
1043 if not cls
._fromResponderQueue
.empty():
1044 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1046 cls
._response
_headers
= response_headers
.getvalue()
1047 return (receivedQuery
, message
)
1050 def sendDOHPostQuery(cls
, port
, servername
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, rawResponse
=False, customHeaders
=[], useHTTPS
=True):
1052 conn
= cls
.openDOHConnection(port
, caFile
=caFile
, timeout
=timeout
)
1053 response_headers
= BytesIO()
1054 #conn.setopt(pycurl.VERBOSE, True)
1055 conn
.setopt(pycurl
.URL
, url
)
1056 conn
.setopt(pycurl
.RESOLVE
, ["%s:%d:127.0.0.1" % (servername
, port
)])
1057 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1058 conn
.setopt(pycurl
.HTTP_VERSION
, pycurl
.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE
)
1060 conn
.setopt(pycurl
.SSL_VERIFYPEER
, 1)
1061 conn
.setopt(pycurl
.SSL_VERIFYHOST
, 2)
1063 conn
.setopt(pycurl
.CAINFO
, caFile
)
1065 conn
.setopt(pycurl
.HTTPHEADER
, customHeaders
)
1066 conn
.setopt(pycurl
.HEADERFUNCTION
, response_headers
.write
)
1067 conn
.setopt(pycurl
.POST
, True)
1070 data
= data
.to_wire()
1072 conn
.setopt(pycurl
.POSTFIELDS
, data
)
1075 cls
._toResponderQueue
.put(response
, True, timeout
)
1077 receivedQuery
= None
1079 cls
._response
_headers
= ''
1080 data
= conn
.perform_rb()
1081 cls
._rcode
= conn
.getinfo(pycurl
.RESPONSE_CODE
)
1082 if cls
._rcode
== 200 and not rawResponse
:
1083 message
= dns
.message
.from_wire(data
)
1087 if useQueue
and not cls
._fromResponderQueue
.empty():
1088 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1090 cls
._response
_headers
= response_headers
.getvalue()
1091 return (receivedQuery
, message
)
def sendDOHQueryWrapper(self, query, response, useQueue=True):
    """Send *query* through sendDOHQuery using this test's DoH settings.

    Port, server name, base URL and CA file are taken from ``_dohServerPort``,
    ``_serverName``, ``_dohBaseURL`` and ``_caCert``. The result of
    sendDOHQuery is returned as-is.
    """
    send = self.sendDOHQuery
    return send(
        self._dohServerPort,
        self._serverName,
        self._dohBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
    """Send *query* through sendDOHQuery to the nghttp2-backed DoH listener.

    Port and base URL come from ``_dohWithNGHTTP2ServerPort`` and
    ``_dohWithNGHTTP2BaseURL``. The server name and CA file come from
    ``_serverName`` and ``_caCert``. The result of sendDOHQuery is returned as-is.
    """
    send = self.sendDOHQuery
    return send(
        self._dohWithNGHTTP2ServerPort,
        self._serverName,
        self._dohWithNGHTTP2BaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
    """Send *query* through sendDOHQuery to the h2o-backed DoH listener.

    Port and base URL come from ``_dohWithH2OServerPort`` and
    ``_dohWithH2OBaseURL``. The server name and CA file come from
    ``_serverName`` and ``_caCert``. The result of sendDOHQuery is returned as-is.
    """
    send = self.sendDOHQuery
    return send(
        self._dohWithH2OServerPort,
        self._serverName,
        self._dohWithH2OBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
def sendDOTQueryWrapper(self, query, response, useQueue=True):
    """Send *query* through sendDOTQuery using this test's DoT settings.

    Port, server name and CA file come from ``_tlsServerPort``, ``_serverName``
    and ``_caCert``. The result of sendDOTQuery is returned as-is.
    """
    send = self.sendDOTQuery
    return send(
        self._tlsServerPort,
        self._serverName,
        query,
        response,
        self._caCert,
        useQueue=useQueue,
    )
def sendDOQQueryWrapper(self, query, response, useQueue=True):
    """Send *query* through sendDOQQuery using this test's DoQ settings.

    Port, CA file and server name come from ``_doqServerPort``, ``_caCert`` and
    ``_serverName``. The result of sendDOQQuery is returned as-is.
    """
    send = self.sendDOQQuery
    return send(
        self._doqServerPort,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
        serverName=self._serverName,
    )
def sendDOH3QueryWrapper(self, query, response, useQueue=True):
    """Send *query* through sendDOH3Query using this test's DoH3 settings.

    The port comes from ``_doh3ServerPort``. The base URL is the shared
    ``_dohBaseURL``. The CA file and server name come from ``_caCert`` and
    ``_serverName``. The result of sendDOH3Query is returned as-is.
    """
    send = self.sendDOH3Query
    return send(
        self._doh3ServerPort,
        self._dohBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
        serverName=self._serverName,
    )
1111 def getDOQConnection(cls
, port
, caFile
=None, source
=None, source_port
=0):
1113 manager
= dns
.quic
.SyncQuicManager(
1117 return manager
.connect('127.0.0.1', port
, source
, source_port
)
1120 def sendDOQQuery(cls
, port
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, fromQueue
=None, toQueue
=None, connection
=None, serverName
=None):
1124 toQueue
.put(response
, True, timeout
)
1126 cls
._toResponderQueue
.put(response
, True, timeout
)
1128 message
= quic_query(query
, '127.0.0.1', timeout
, port
, verify
=caFile
, server_hostname
=serverName
)
1130 receivedQuery
= None
1134 if not fromQueue
.empty():
1135 receivedQuery
= fromQueue
.get(True, timeout
)
1137 if not cls
._fromResponderQueue
.empty():
1138 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1140 return (receivedQuery
, message
)
1143 def sendDOH3Query(cls
, port
, baseurl
, query
, response
=None, timeout
=2.0, caFile
=None, useQueue
=True, rawQuery
=False, fromQueue
=None, toQueue
=None, connection
=None, serverName
=None, post
=False):
1147 toQueue
.put(response
, True, timeout
)
1149 cls
._toResponderQueue
.put(response
, True, timeout
)
1151 message
= doh3_query(query
, baseurl
, timeout
, port
, verify
=caFile
, server_hostname
=serverName
, post
=post
)
1153 receivedQuery
= None
1157 if not fromQueue
.empty():
1158 receivedQuery
= fromQueue
.get(True, timeout
)
1160 if not cls
._fromResponderQueue
.empty():
1161 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
1163 return (receivedQuery
, message
)