]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: don't close DoH on timeout, do on release 17549/head
authorKarel Bilek <kb@karelbilek.com>
Tue, 9 Jun 2026 10:17:24 +0000 (12:17 +0200)
committerKarel Bilek <kb@karelbilek.com>
Wed, 10 Jun 2026 09:05:23 +0000 (11:05 +0200)
Currently, dnsdist closes outgoing DoH connections on
receive timeout, instead of keeping the TCP connection alive.

On the other hand, when the DoH max idle time is
reached, the connection is not terminated properly and
its file descriptor is not closed.

The second issue was hard to reach before, as the
cleanup interval would need to be lower than the TCP
read timeout.

There is a new test case for this, which needed a
slight refactor of the test DoH threads.

Signed-off-by: Karel Bilek <kb@karelbilek.com>
pdns/dnsdistdist/dnsdist-nghttp2.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_HealthChecks.py
regression-tests.dnsdist/test_OutgoingDOH.py
regression-tests.dnsdist/test_TimeoutResponse.py

index 115fd10e6105853d11edf3f17ed7b7caab30c755..28e7d5380e0eb8731de51db37b3155e8effa8958 100644 (file)
@@ -79,6 +79,10 @@ public:
   void release(bool removeFromCache) override
   {
     (void)removeFromCache;
+    if (d_ioState) {
+      d_ioState.reset();
+    }
+    nghttp2_session_terminate_session(d_session.get(), NGHTTP2_NO_ERROR);
   }
 
 private:
@@ -546,7 +550,7 @@ void DoHConnectionToBackend::updateIO(IOState newState, const FDMultiplexer::cal
 void DoHConnectionToBackend::watchForRemoteHostClosingConnection()
 {
   if (willBeReusable(false) && !d_healthCheckQuery) {
-    updateIO(IOState::NeedRead, handleReadableIOCallback, false);
+    updateIO(IOState::NeedRead, handleReadableIOCallback, true);
   }
 }
 
index 55f0142640bd500822866df46b7d560596d701f1..cd86a0a96535d2c535930134751d5d483763f7b0 100644 (file)
@@ -567,6 +567,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         callback,
         tlsContext,
         useProxyProtocol,
+        closeConnCallback,
     ):
         ignoreTrailing = trailingDataResponse is True
         try:
@@ -588,18 +589,24 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             header = conn.recv(proxy.HEADER_SIZE)
             if not header:
                 print("unable to get header")
+                if closeConnCallback:
+                    closeConnCallback(conn)
                 conn.close()
                 return
 
             if not proxy.parseHeader(header):
                 print("unable to parse header")
                 print(header)
+                if closeConnCallback:
+                    closeConnCallback(conn)
                 conn.close()
                 return
 
             proxyContent = conn.recv(proxy.contentLen)
             if not proxyContent:
                 print("unable to get content")
+                if closeConnCallback:
+                    closeConnCallback(conn)
                 conn.close()
                 return
 
@@ -619,6 +626,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
             events = h2conn.receive_data(data)
             for event in events:
+                if isinstance(event, h2.events.ConnectionTerminated):
+                    break
                 if isinstance(event, h2.events.RequestReceived):
                     requestHeaders = event.headers
                 if isinstance(event, h2.events.DataReceived):
@@ -639,13 +648,15 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                             forceRcode = trailingDataResponse
 
                         if callback:
-                            status, wire = callback(request, requestHeaders, fromQueue, toQueue)
+                            status, wire = callback(request, requestHeaders, fromQueue, toQueue, conn)
                         else:
                             response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                             if response:
                                 wire = response.to_wire(max_size=65535)
 
                         if not wire:
+                            if closeConnCallback:
+                                closeConnCallback(conn)
                             conn.close()
                             conn = None
                             break
@@ -668,6 +679,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                 break
 
         if conn is not None:
+            if closeConnCallback:
+                closeConnCallback(conn)
             conn.close()
 
     @classmethod
@@ -681,6 +694,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         callback=None,
         tlsContext=None,
         useProxyProtocol=False,
+        closeConnCallback=False,
+        connTimeout=5.0,
     ):
         cls._backgroundThreads[threading.get_native_id()] = True
         # trailingDataResponse=True means "ignore trailing data".
@@ -718,7 +733,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                 else:
                     continue
 
-            conn.settimeout(5.0)
+            conn.settimeout(connTimeout)
             thread = threading.Thread(
                 name="DoH Connection Handler",
                 target=cls.handleDoHConnection,
@@ -732,6 +747,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                     callback,
                     tlsContext,
                     useProxyProtocol,
+                    closeConnCallback,
                 ],
             )
             thread.daemon = True
index 7d5528925d33200c2dd3c494ee76508a1b348436..d4dd61f45ee5dbb1f949482fb85dabe3940001f1 100644 (file)
@@ -391,7 +391,7 @@ class TestLazyHealthChecks(HealthCheckTest):
         return cls.HandleDNSQuery(request)
 
     @classmethod
-    def DoHCallback(cls, request, requestHeaders, fromQueue, toQueue):
+    def DoHCallback(cls, request, requestHeaders, fromQueue, toQueue, conn):
         global _dohHealthCheckQueries
         if str(request.question[0].name).startswith("a.root-servers.net"):
             _dohHealthCheckQueries = _dohHealthCheckQueries + 1
index 8413773ed897a74a898fe56849be9b6e0b69359f..10e80da63f0d28a71d8f65e47b87b47e824f9931 100644 (file)
@@ -617,7 +617,7 @@ class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenRespon
     addAction(SuffixMatchNodeRule(smn), PoolAction('cache'))
     """
 
-    def callback(request, headers, fromQueue, toQueue):
+    def callback(request, headers, fromQueue, toQueue, conn):
 
         if str(request.question[0].name) == "500-status.broken-responses.outgoing-doh.test.powerdns.com.":
             print("returning 500")
@@ -672,7 +672,7 @@ class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenRespons
     """
     _verboseMode = True
 
-    def callback(request, headers, fromQueue, toQueue):
+    def callback(request, headers, fromQueue, toQueue, conn):
 
         if str(request.question[0].name) == "500-status.broken-responses.outgoing-doh.test.powerdns.com.":
             print("returning 500")
@@ -776,8 +776,7 @@ class TestOutgoingDOHXForwarded(DNSDistTest):
     """
     _verboseMode = True
 
-    def callback(request, headersList, fromQueue, toQueue):
-
+    def callback(request, headersList, fromQueue, toQueue, conn):
         if str(request.question[0].name) == "a.root-servers.net.":
             # do not check headers on health-check queries
             return 200, dns.message.make_response(request).to_wire()
@@ -845,3 +844,89 @@ class TestOutgoingDOHXForwarded(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
         self.assertEqual(query, receivedQuery)
         self.assertEqual(receivedResponse, expectedResponse)
+
+
+class TestOutgoingDOHKeepsConnection(DNSDistTest):
+    _tlsBackendPort = pickAvailablePort()
+    _config_params = ["_tlsBackendPort"]
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%d", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', tcpRecvTimeout=2}
+    setDoHDownstreamMaxIdleTime(8)
+    setDoHDownstreamCleanupInterval(4)
+    """
+    _verboseMode = True
+
+    _upconns = []
+    _lock = threading.Lock()
+
+    @classmethod
+    def callback(cls, request, headersList, fromQueue, toQueue, conn):
+        # we want to track up connections on the backend, but not the healthchecks
+        # however we cannot tell connections apart until they send a request
+        if str(request.question[0].name) == "a.root-servers.net.":
+            return 200, dns.message.make_response(request).to_wire()
+
+        with cls._lock:
+            if conn not in cls._upconns:
+                cls._upconns.append(conn)
+
+        return 200, dns.message.make_response(request).to_wire()
+
+    @classmethod
+    def closeConnCallback(cls, conn):
+        with cls._lock:
+            if conn in cls._upconns:
+                cls._upconns.remove(conn)
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.set_alpn_protocols(["h2"])
+        tlsContext.load_cert_chain("server.chain", "server.key")
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(
+            name="DOH Responder",
+            target=cls.DOHResponder,
+            args=[
+                cls._tlsBackendPort,
+                cls._toResponderQueue,
+                cls._fromResponderQueue,
+                False,
+                False,
+                cls.callback,
+                tlsContext,
+                False,
+                cls.closeConnCallback,
+                35.0,
+            ],
+        )
+        cls._DOHResponder.daemon = True
+        cls._DOHResponder.start()
+
+    def testKeepsAndClosesConnection(self):
+        """
+        Outgoing DOH: keeps and closes idle DOH connections properly
+        """
+        name = "anything.powerdns.com."
+        query = dns.message.make_query(name, "A", "IN")
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, "127.0.0.1")
+        expectedResponse.answer.append(rrset)
+
+        self.sendUDPQuery(query, expectedResponse)
+        time.sleep(3)
+        with self._lock:
+            up = len(self._upconns)
+
+        self.assertEqual(up, 1)
+
+        for _ in range(10):
+            with self._lock:
+                up = len(self._upconns)
+            if up == 0:
+                break
+            time.sleep(1)
+
+        self.assertEqual(up, 0)
index 4419efd18be1de19e1f4281655ea1ef038ebb15f..9950e4d4a00ad259824aa3784f8a9f7a89d3e483 100644 (file)
@@ -39,11 +39,11 @@ def normalResponseCallback(request):
     return response.to_wire()
 
 
-def dohTimeoutResponseCallback(request, headers, fromQueue, toQueue):
+def dohTimeoutResponseCallback(request, headers, fromQueue, toQueue, conn):
     return 200, timeoutResponseCallback(request)
 
 
-def dohNormalResponseCallback(request, headers, fromQueue, toQueue):
+def dohNormalResponseCallback(request, headers, fromQueue, toQueue, conn):
     return 200, normalResponseCallback(request)