import pycurl
from io import BytesIO
+from doqclient import quic_query
+from doh3client import doh3_query
+
from eqdnsmessage import AssertEqualDNSMessageMixin
from proxyprotocol import ProxyProtocol
except NameError:
pass
+def getWorkerID():
+ if not 'PYTEST_XDIST_WORKER' in os.environ:
+ return 0
+ workerName = os.environ['PYTEST_XDIST_WORKER']
+ return int(workerName[2:])
+
+workerPorts = {}
+
+def pickAvailablePort():
+ global workerPorts
+ workerID = getWorkerID()
+ if workerID in workerPorts:
+ port = workerPorts[workerID] + 1
+ else:
+ port = 11000 + (workerID * 1000)
+ workerPorts[workerID] = port
+ return port
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
"""
from dnsdist on a separate queue, allowing the tests to check
that the queries sent from dnsdist were as expected.
"""
- _dnsDistPort = 5340
_dnsDistListeningAddr = "127.0.0.1"
- _testServerPort = 5350
_toResponderQueue = Queue()
_fromResponderQueue = Queue()
_queueTimeout = 1
"""
_config_params = ['_testServerPort']
_acl = ['127.0.0.1/32']
- _consolePort = 5199
_consoleKey = None
_healthCheckName = 'a.root-servers.net.'
_healthCheckCounter = 0
_UDPResponder = None
_TCPResponder = None
_extraStartupSleep = 0
+ _dnsDistPort = pickAvailablePort()
+ _consolePort = pickAvailablePort()
+ _testServerPort = pickAvailablePort()
@classmethod
def waitForTCPSocket(cls, ipaddress, port):
@classmethod
def startResponders(cls):
print("Launching responders..")
+ cls._testServerPort = pickAvailablePort()
cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
- cls._UDPResponder.setDaemon(True)
+ cls._UDPResponder.daemon = True
cls._UDPResponder.start()
cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
- cls._TCPResponder.setDaemon(True)
+ cls._TCPResponder.daemon = True
cls._TCPResponder.start()
cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
@classmethod
def startDNSDist(cls):
+ cls._dnsDistPort = pickAvailablePort()
+ cls._consolePort = pickAvailablePort()
+
print("Launching dnsdist..")
confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
params = tuple([getattr(cls, param) for param in cls._config_params])
print(params)
with open(confFile, 'w') as conf:
conf.write("-- Autogenerated by dnsdisttests.py\n")
+ conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}")
conf.write(cls._config_template % params)
conf.write("setSecurityPollSuffix('')")
else:
expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
if not cls._verboseMode and output != expectedOutput:
- raise AssertionError('dnsdist --check-config failed: %s' % output)
+ raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput))
logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
with open(logFile, 'w') as fdLog:
@classmethod
def _ResponderIncrementCounter(cls):
- if threading.currentThread().name in cls._responsesCounter:
- cls._responsesCounter[threading.currentThread().name] += 1
+ if threading.current_thread().name in cls._responsesCounter:
+ cls._responsesCounter[threading.current_thread().name] += 1
else:
- cls._responsesCounter[threading.currentThread().name] = 1
+ cls._responsesCounter[threading.current_thread().name] = 1
@classmethod
def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
sock.close()
@classmethod
- def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
+ def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False):
ignoreTrailing = trailingDataResponse is True
- data = conn.recv(2)
+ try:
+ data = conn.recv(2)
+ except Exception as err:
+ data = None
+ print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
if not data:
conn.close()
return
conn.close()
return
- conn.send(struct.pack("!H", len(wire)))
+ wireLen = struct.pack("!H", len(wire))
+ if partialWrite:
+ for b in wireLen:
+ conn.send(bytes([b]))
+ time.sleep(0.5)
+ else:
+ conn.send(wireLen)
conn.send(wire)
while multipleResponses:
conn.close()
@classmethod
- def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'):
+ def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False):
cls._backgroundThreads[threading.get_native_id()] = True
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
if multipleConnections:
thread = threading.Thread(name='TCP Connection Handler',
target=cls.handleTCPConnection,
- args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback])
- thread.setDaemon(True)
+ args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite])
+ thread.daemon = True
thread.start()
else:
- cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback)
+ cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite)
sock.close()
except ssl.SSLEOFError as e:
print("Unexpected EOF: %s" % (e))
return
+ except Exception as err:
+ print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
+ return
dnsData = {}
# be careful, HTTP/2 headers and data might be in different recv() results
requestHeaders = None
while True:
- data = conn.recv(65535)
+ try:
+ data = conn.recv(65535)
+ except Exception as err:
+ data = None
+ print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
if not data:
break
thread = threading.Thread(name='DoH Connection Handler',
target=cls.handleDoHConnection,
args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
- thread.setDaemon(True)
+ thread.daemon = True
thread.start()
sock.close()
return sock
@classmethod
- def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
+ def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if timeout:
# 2.7.9+
if hasattr(ssl, 'create_default_context'):
sslctx = ssl.create_default_context(cafile=caCert)
+ if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
+ sslctx.set_alpn_protocols(alpn)
sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
else:
sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
@classmethod
def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
- print("reading data")
message = None
data = sock.recv(2)
if data:
print(useQueue)
if useQueue and not cls._fromResponderQueue.empty():
receivedQuery = cls._fromResponderQueue.get(True, timeout)
- print("Got from queue")
print(receivedQuery)
return (receivedQuery, message)
else:
receivedQuery = None
print(useQueue)
if useQueue and not cls._fromResponderQueue.empty():
- print("Got from queue")
print(receivedQuery)
receivedQuery = cls._fromResponderQueue.get(True, timeout)
else:
def checkResponseNoEDNS(self, expected, received):
self.checkMessageNoEDNS(expected, received)
- def generateNewCertificateAndKey(self):
+ @staticmethod
+ def generateNewCertificateAndKey(filePrefix):
# generate and sign a new cert
- cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf']
+ cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf']
output = None
try:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
output = process.communicate(input='')
except subprocess.CalledProcessError as exc:
raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
- cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
+ cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
output = None
try:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
except subprocess.CalledProcessError as exc:
raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
- with open('server.chain', 'w') as outFile:
- for inFileName in ['server.pem', 'ca.pem']:
+ with open(filePrefix + '.chain', 'w') as outFile:
+ for inFileName in [filePrefix + '.pem', 'ca.pem']:
with open(inFileName) as inFile:
outFile.write(inFile.read())
- cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12']
+ cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12']
output = None
try:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
return conn
@classmethod
- def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+ def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
url = cls.getDOHGetURL(baseurl, query, rawQuery)
- conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
- response_headers = BytesIO()
- #conn.setopt(pycurl.VERBOSE, True)
- conn.setopt(pycurl.URL, url)
- conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+
+ if not conn:
+ conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+ # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+ conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
+
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
if caFile:
conn.setopt(pycurl.CAINFO, caFile)
+ response_headers = BytesIO()
+ #conn.setopt(pycurl.VERBOSE, True)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+
conn.setopt(pycurl.HTTPHEADER, customHeaders)
conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+ # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+ conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
def sendDOHQueryWrapper(self, query, response, useQueue=True):
return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+ def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
+ return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+
+ def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
+ return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+
def sendDOTQueryWrapper(self, query, response, useQueue=True):
return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue)
+
+ def sendDOQQueryWrapper(self, query, response, useQueue=True):
+ return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
+
+ def sendDOH3QueryWrapper(self, query, response, useQueue=True):
+ return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
+ @classmethod
+ def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
+
+ manager = dns.quic.SyncQuicManager(
+ verify_mode=caFile
+ )
+
+ return manager.connect('127.0.0.1', port, source, source_port)
+
+ @classmethod
+ def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
+
+ if response:
+ if toQueue:
+ toQueue.put(response, True, timeout)
+ else:
+ cls._toResponderQueue.put(response, True, timeout)
+
+ (message, _) = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)
+
+ receivedQuery = None
+
+ if useQueue:
+ if fromQueue:
+ if not fromQueue.empty():
+ receivedQuery = fromQueue.get(True, timeout)
+ else:
+ if not cls._fromResponderQueue.empty():
+ receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+ return (receivedQuery, message)
+
+ @classmethod
+ def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
+
+ if response:
+ if toQueue:
+ toQueue.put(response, True, timeout)
+ else:
+ cls._toResponderQueue.put(response, True, timeout)
+
+ message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)
+
+ receivedQuery = None
+
+ if useQueue:
+ if fromQueue:
+ if not fromQueue.empty():
+ receivedQuery = fromQueue.get(True, timeout)
+ else:
+ if not cls._fromResponderQueue.empty():
+ receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+ return (receivedQuery, message)