From: Karel Bilek Date: Tue, 9 Jun 2026 10:17:24 +0000 (+0200) Subject: dnsdist: don't close DoH on timeout, do on release X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=441bc7ffe8be11ab180ecca56af012ccacb6091c;p=thirdparty%2Fpdns.git dnsdist: don't close DoH on timeout, do on release Currently, dnsdist closes outgoing DoH on receive timeout, instead of keeping TCP connection alive. On the other hand, when DoH max idle time is reached, the connection is not closed properly and the filedescriptor is not properly closed. The second issue was hard to reach before, as cleanup interval would need to be lower than TCP read timeout. There is a new testcase for this, which needed slight refactor of test DoH threads. Signed-off-by: Karel Bilek --- diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 115fd10e61..28e7d5380e 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -79,6 +79,10 @@ public: void release(bool removeFromCache) override { (void)removeFromCache; + if (d_ioState) { + d_ioState.reset(); + } + nghttp2_session_terminate_session(d_session.get(), NGHTTP2_NO_ERROR); } private: @@ -546,7 +550,7 @@ void DoHConnectionToBackend::updateIO(IOState newState, const FDMultiplexer::cal void DoHConnectionToBackend::watchForRemoteHostClosingConnection() { if (willBeReusable(false) && !d_healthCheckQuery) { - updateIO(IOState::NeedRead, handleReadableIOCallback, false); + updateIO(IOState::NeedRead, handleReadableIOCallback, true); } } diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 55f0142640..cd86a0a965 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -567,6 +567,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): callback, tlsContext, useProxyProtocol, + closeConnCallback, ): ignoreTrailing = trailingDataResponse is True try: @@ -588,18 +589,24 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): header = conn.recv(proxy.HEADER_SIZE) if not header: print("unable to get header") + if closeConnCallback: + closeConnCallback(conn) conn.close() return if not proxy.parseHeader(header): print("unable to parse header") print(header) + if closeConnCallback: + closeConnCallback(conn) conn.close() return proxyContent = conn.recv(proxy.contentLen) if not proxyContent: print("unable to get content") + if closeConnCallback: + closeConnCallback(conn) conn.close() return @@ -619,6 +626,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): events = h2conn.receive_data(data) for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + break if isinstance(event, h2.events.RequestReceived): requestHeaders = event.headers if isinstance(event, h2.events.DataReceived): @@ -639,13 +648,15 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): forceRcode = trailingDataResponse if callback: - status, wire = callback(request, requestHeaders, fromQueue, toQueue) + status, wire = callback(request, requestHeaders, fromQueue, toQueue, conn) else: response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) if response: wire = response.to_wire(max_size=65535) if not wire: + if closeConnCallback: + closeConnCallback(conn) conn.close() conn = None break @@ -668,6 +679,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): break if conn is not None: + if closeConnCallback: + closeConnCallback(conn) conn.close() @classmethod @@ -681,6 +694,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): callback=None, tlsContext=None, useProxyProtocol=False, + closeConnCallback=False, + connTimeout=5.0, ): cls._backgroundThreads[threading.get_native_id()] = True # trailingDataResponse=True means "ignore trailing data". @@ -718,7 +733,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): else: continue - conn.settimeout(5.0) + conn.settimeout(connTimeout) thread = threading.Thread( name="DoH Connection Handler", target=cls.handleDoHConnection, @@ -732,6 +747,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): callback, tlsContext, useProxyProtocol, + closeConnCallback, ], ) thread.daemon = True diff --git a/regression-tests.dnsdist/test_HealthChecks.py b/regression-tests.dnsdist/test_HealthChecks.py index 7d5528925d..d4dd61f45e 100644 --- a/regression-tests.dnsdist/test_HealthChecks.py +++ b/regression-tests.dnsdist/test_HealthChecks.py @@ -391,7 +391,7 @@ class TestLazyHealthChecks(HealthCheckTest): return cls.HandleDNSQuery(request) @classmethod - def DoHCallback(cls, request, requestHeaders, fromQueue, toQueue): + def DoHCallback(cls, request, requestHeaders, fromQueue, toQueue, conn): global _dohHealthCheckQueries if str(request.question[0].name).startswith("a.root-servers.net"): _dohHealthCheckQueries = _dohHealthCheckQueries + 1 diff --git a/regression-tests.dnsdist/test_OutgoingDOH.py b/regression-tests.dnsdist/test_OutgoingDOH.py index 8413773ed8..10e80da63f 100644 --- a/regression-tests.dnsdist/test_OutgoingDOH.py +++ b/regression-tests.dnsdist/test_OutgoingDOH.py @@ -617,7 +617,7 @@ class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenRespon addAction(SuffixMatchNodeRule(smn), PoolAction('cache')) """ - def callback(request, headers, fromQueue, toQueue): + def callback(request, headers, fromQueue, toQueue, conn): if str(request.question[0].name) == "500-status.broken-responses.outgoing-doh.test.powerdns.com.": print("returning 500") @@ -672,7 +672,7 @@ class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenRespons """ _verboseMode = True - def callback(request, headers, fromQueue, toQueue): + def callback(request, headers, fromQueue, toQueue, conn): if str(request.question[0].name) == "500-status.broken-responses.outgoing-doh.test.powerdns.com.": print("returning 500") @@ -776,8 +776,7 @@ class TestOutgoingDOHXForwarded(DNSDistTest): """ _verboseMode = True - def callback(request, headersList, fromQueue, toQueue): - + def callback(request, headersList, fromQueue, toQueue, conn): if str(request.question[0].name) == "a.root-servers.net.": # do not check headers on health-check queries return 200, dns.message.make_response(request).to_wire() @@ -845,3 +844,89 @@ class TestOutgoingDOHXForwarded(DNSDistTest): (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse) self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, expectedResponse) + + +class TestOutgoingDOHKeepsConnection(DNSDistTest): + _tlsBackendPort = pickAvailablePort() + _config_params = ["_tlsBackendPort"] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%d", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', tcpRecvTimeout=2} + setDoHDownstreamMaxIdleTime(8) + setDoHDownstreamCleanupInterval(4) + """ + _verboseMode = True + + _upconns = [] + _lock = threading.Lock() + + @classmethod + def callback(cls, request, headersList, fromQueue, toQueue, conn): + # we want to track up connections on the backend, but not the healthchecks + # however we cannot tell connections apart until they send a request + if str(request.question[0].name) == "a.root-servers.net.": + return 200, dns.message.make_response(request).to_wire() + + with cls._lock: + if conn not in cls._upconns: + cls._upconns.append(conn) + + return 200, dns.message.make_response(request).to_wire() + + @classmethod + def closeConnCallback(cls, conn): + with cls._lock: + if conn in cls._upconns: + cls._upconns.remove(conn) + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.set_alpn_protocols(["h2"]) + tlsContext.load_cert_chain("server.chain", "server.key") + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread( + name="DOH Responder", + target=cls.DOHResponder, + args=[ + cls._tlsBackendPort, + cls._toResponderQueue, + cls._fromResponderQueue, + False, + False, + cls.callback, + tlsContext, + False, + cls.closeConnCallback, + 35.0, + ], + ) + cls._DOHResponder.daemon = True + cls._DOHResponder.start() + + def testKeepsAndClosesConnection(self): + """ + Outgoing DOH: keeps and closes idle DOH connections properly + """ + name = "anything.powerdns.com." + query = dns.message.make_query(name, "A", "IN") + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, "127.0.0.1") + expectedResponse.answer.append(rrset) + + self.sendUDPQuery(query, expectedResponse) + time.sleep(3) + with self._lock: + up = len(self._upconns) + + self.assertEqual(up, 1) + + for _ in range(10): + with self._lock: + up = len(self._upconns) + if up == 0: + break + time.sleep(1) + + self.assertEqual(up, 0) diff --git a/regression-tests.dnsdist/test_TimeoutResponse.py b/regression-tests.dnsdist/test_TimeoutResponse.py index 4419efd18b..9950e4d4a0 100644 --- a/regression-tests.dnsdist/test_TimeoutResponse.py +++ b/regression-tests.dnsdist/test_TimeoutResponse.py @@ -39,11 +39,11 @@ def normalResponseCallback(request): return response.to_wire() -def dohTimeoutResponseCallback(request, headers, fromQueue, toQueue): +def dohTimeoutResponseCallback(request, headers, fromQueue, toQueue, conn): return 200, timeoutResponseCallback(request) -def dohNormalResponseCallback(request, headers, fromQueue, toQueue): +def dohNormalResponseCallback(request, headers, fromQueue, toQueue, conn): return 200, normalResponseCallback(request)