callback,
tlsContext,
useProxyProtocol,
+ closeConnCallback,
):
ignoreTrailing = trailingDataResponse is True
try:
header = conn.recv(proxy.HEADER_SIZE)
if not header:
print("unable to get header")
+ if closeConnCallback:
+ closeConnCallback(conn)
conn.close()
return
if not proxy.parseHeader(header):
print("unable to parse header")
print(header)
+ if closeConnCallback:
+ closeConnCallback(conn)
conn.close()
return
proxyContent = conn.recv(proxy.contentLen)
if not proxyContent:
print("unable to get content")
+ if closeConnCallback:
+ closeConnCallback(conn)
conn.close()
return
events = h2conn.receive_data(data)
for event in events:
+ if isinstance(event, h2.events.ConnectionTerminated):
+ break
if isinstance(event, h2.events.RequestReceived):
requestHeaders = event.headers
if isinstance(event, h2.events.DataReceived):
forceRcode = trailingDataResponse
if callback:
- status, wire = callback(request, requestHeaders, fromQueue, toQueue)
+ status, wire = callback(request, requestHeaders, fromQueue, toQueue, conn)
else:
response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if response:
wire = response.to_wire(max_size=65535)
if not wire:
+ if closeConnCallback:
+ closeConnCallback(conn)
conn.close()
conn = None
break
break
if conn is not None:
+ if closeConnCallback:
+ closeConnCallback(conn)
conn.close()
@classmethod
callback=None,
tlsContext=None,
useProxyProtocol=False,
+ closeConnCallback=False,
+ connTimeout=5.0,
):
cls._backgroundThreads[threading.get_native_id()] = True
# trailingDataResponse=True means "ignore trailing data".
else:
continue
- conn.settimeout(5.0)
+ conn.settimeout(connTimeout)
thread = threading.Thread(
name="DoH Connection Handler",
target=cls.handleDoHConnection,
callback,
tlsContext,
useProxyProtocol,
+ closeConnCallback,
],
)
thread.daemon = True
addAction(SuffixMatchNodeRule(smn), PoolAction('cache'))
"""
- def callback(request, headers, fromQueue, toQueue):
+ def callback(request, headers, fromQueue, toQueue, conn):
if str(request.question[0].name) == "500-status.broken-responses.outgoing-doh.test.powerdns.com.":
print("returning 500")
"""
_verboseMode = True
- def callback(request, headers, fromQueue, toQueue):
+ def callback(request, headers, fromQueue, toQueue, conn):
if str(request.question[0].name) == "500-status.broken-responses.outgoing-doh.test.powerdns.com.":
print("returning 500")
"""
_verboseMode = True
- def callback(request, headersList, fromQueue, toQueue):
-
+ def callback(request, headersList, fromQueue, toQueue, conn):
if str(request.question[0].name) == "a.root-servers.net.":
# do not check headers on health-check queries
return 200, dns.message.make_response(request).to_wire()
(receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
self.assertEqual(query, receivedQuery)
self.assertEqual(receivedResponse, expectedResponse)
+
+
+class TestOutgoingDOHKeepsConnection(DNSDistTest):
+ _tlsBackendPort = pickAvailablePort()
+ _config_params = ["_tlsBackendPort"]
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%d", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', tcpRecvTimeout=2}
+ setDoHDownstreamMaxIdleTime(8)
+ setDoHDownstreamCleanupInterval(4)
+ """
+ _verboseMode = True
+
+ _upconns = []
+ _lock = threading.Lock()
+
+ @classmethod
+ def callback(cls, request, headersList, fromQueue, toQueue, conn):
+ # we want to track up connections on the backend, but not the healthchecks
+ # however we cannot tell connections apart until they send a request
+ if str(request.question[0].name) == "a.root-servers.net.":
+ return 200, dns.message.make_response(request).to_wire()
+
+ with cls._lock:
+ if conn not in cls._upconns:
+ cls._upconns.append(conn)
+
+ return 200, dns.message.make_response(request).to_wire()
+
+ @classmethod
+ def closeConnCallback(cls, conn):
+ with cls._lock:
+ if conn in cls._upconns:
+ cls._upconns.remove(conn)
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.set_alpn_protocols(["h2"])
+ tlsContext.load_cert_chain("server.chain", "server.key")
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(
+ name="DOH Responder",
+ target=cls.DOHResponder,
+ args=[
+ cls._tlsBackendPort,
+ cls._toResponderQueue,
+ cls._fromResponderQueue,
+ False,
+ False,
+ cls.callback,
+ tlsContext,
+ False,
+ cls.closeConnCallback,
+ 35.0,
+ ],
+ )
+ cls._DOHResponder.daemon = True
+ cls._DOHResponder.start()
+
+ def testKeepsAndClosesConnection(self):
+ """
+ Outgoing DOH: keeps and closes idle DOH connections properly
+ """
+ name = "anything.powerdns.com."
+ query = dns.message.make_query(name, "A", "IN")
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, "127.0.0.1")
+ expectedResponse.answer.append(rrset)
+
+ self.sendUDPQuery(query, expectedResponse)
+ time.sleep(3)
+ with self._lock:
+ up = len(self._upconns)
+
+ self.assertEqual(up, 1)
+
+ for _ in range(10):
+ with self._lock:
+ up = len(self._upconns)
+ if up == 0:
+ break
+ time.sleep(1)
+
+ self.assertEqual(up, 0)