]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Allow listening on a different addr for DoQ/DoH3 tests
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 13 Jan 2026 13:57:37 +0000 (14:57 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 13 Jan 2026 14:53:14 +0000 (15:53 +0100)
Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/doh3client.py

index 24e2c22ed05bc7363c6636d986f1bd379dd62a1c..db03842835fb70a7f77813ad782fc8fedb9b889b 100644 (file)
@@ -28,8 +28,8 @@ import h2.config
 import pycurl
 from io import BytesIO
 
-from doqclient import quic_query
 from doh3client import doh3_query
+import doqclient
 
 from eqdnsmessage import AssertEqualDNSMessageMixin
 from proxyprotocol import ProxyProtocol
@@ -227,7 +227,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         print("Setting up UDP socket..")
         cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         cls._sock.settimeout(2.0)
-        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
+        cls._sock.connect((cls._dnsDistListeningAddr, cls._dnsDistPort))
 
     @classmethod
     def killProcess(cls, p):
@@ -671,7 +671,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         if not port:
           port = cls._dnsDistPort
 
-        sock.connect(("127.0.0.1", port))
+        sock.connect((cls._dnsDistListeningAddr, port))
         return sock
 
     @classmethod
@@ -691,7 +691,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         else:
             sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
 
-        sslsock.connect(("127.0.0.1", port))
+        sslsock.connect((cls._dnsDistListeningAddr, port))
         return sslsock
 
     @classmethod
@@ -778,7 +778,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         if timeout:
             sock.settimeout(timeout)
 
-        sock.connect(("127.0.0.1", cls._dnsDistPort))
+        sock.connect((cls._dnsDistListeningAddr, cls._dnsDistPort))
         messages = []
 
         try:
@@ -1064,7 +1064,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         response_headers = BytesIO()
         #conn.setopt(pycurl.VERBOSE, True)
         conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:%s" % (servername, port, cls._dnsDistListeningAddr)])
 
         conn.setopt(pycurl.HTTPHEADER, customHeaders)
         conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
@@ -1103,7 +1103,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         response_headers = BytesIO()
         #conn.setopt(pycurl.VERBOSE, True)
         conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:%s" % (servername, port, cls._dnsDistListeningAddr)])
         # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
         conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
         if useHTTPS:
@@ -1152,11 +1152,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     def sendDOTQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
         return self.sendDOTQuery(self._tlsServerPort, self._serverName if not serverName else serverName, query, response, self._caCert, useQueue=useQueue, timeout=timeout)
 
-    def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
-        return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout)
+    def sendDOQQueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None, passExceptions=False):
+        return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout, passExceptions=passExceptions)
+
+    def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None, passExceptions=False):
+        return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout, passExceptions=passExceptions)
 
-    def sendDOH3QueryWrapper(self, query, response, useQueue=True, timeout=2, serverName=None):
-        return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName if not serverName else serverName, timeout=timeout)
     @classmethod
     def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
 
@@ -1164,10 +1165,10 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             verify_mode=caFile
         )
 
-        return manager.connect('127.0.0.1', port, source, source_port)
+        return manager.connect(cls._dnsDistListeningAddr, port, source, source_port)
 
     @classmethod
-    def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
+    def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, passExceptions=False):
 
         if response:
             if toQueue:
@@ -1175,7 +1176,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             else:
                 cls._toResponderQueue.put(response, True, timeout)
 
-        (message, _) = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)
+        try:
+            (message, _) = doqclient.quic_query(query, cls._dnsDistListeningAddr, timeout, port, verify=caFile, server_hostname=serverName)
+        except doqclient.StreamResetError as e:
+            if passExceptions:
+                raise
+            return (None, None)
 
         receivedQuery = None
 
@@ -1190,7 +1196,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return (receivedQuery, message)
 
     @classmethod
-    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None, rawResponse=False):
+    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None, rawResponse=False, passExceptions=False):
 
         if response:
             if toQueue:
@@ -1199,9 +1205,14 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                 cls._toResponderQueue.put(response, True, timeout)
 
         if rawResponse:
-          return doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)
+          return doh3_query(query, cls._dnsDistListeningAddr, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)
 
-        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)
+        try:
+            message = doh3_query(query, cls._dnsDistListeningAddr, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)
+        except doqclient.StreamResetError as e:
+          if passExceptions:
+                raise
+          return (None, None)
 
         receivedQuery = None
 
index f0b2c428f78b9b283251064973b6cedef9afb058..2269d56808e62a03f2c047f681c32aeb53918e00 100644 (file)
@@ -180,6 +180,7 @@ async def perform_http_request(
 
 async def async_h3_query(
     configuration: QuicConfiguration,
+    host: str,
     baseurl: str,
     port: int,
     query: dns.message,
@@ -193,7 +194,7 @@ async def async_h3_query(
     if not post:
         url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('='))
     async with connect(
-        "127.0.0.1",
+        host,
         port,
         configuration=configuration,
         create_protocol=create_protocol,
@@ -217,7 +218,7 @@ async def async_h3_query(
             return (e,{})
 
 
-def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
+def doh3_query(query, host, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
     configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name=server_hostname)
     if verify:
         configuration.load_verify_locations(verify)
@@ -225,6 +226,7 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname
     (result, headers) = asyncio.run(
         async_h3_query(
             configuration=configuration,
+            host=host,
             baseurl=baseurl,
             port=port,
             query=query,