]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Implement async processing of queries and responses
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 16 Dec 2022 17:31:33 +0000 (18:31 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 13 Jan 2023 15:57:47 +0000 (16:57 +0100)
26 files changed:
pdns/dnsdist-idstate.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-async.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-async.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-healthchecks.cc
pdns/dnsdistdist/dnsdist-internal-queries.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-internal-queries.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-bindings-network.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-lua-network.cc
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/test-dnsdistasync.cc [new file with mode: 0644]
pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/doh.hh
pdns/lock.hh
pdns/test-dnsdist_cc.cc

index 5b42f986f561f764767b9835298904f816ec441a..cec3a329ab470d341c7c5126eda1c6038423200a 100644 (file)
@@ -152,6 +152,7 @@ struct InternalQueryState
   bool dnssecOK{false};
   bool useZeroScope{false};
   bool forwardedOverUDP{false};
+  bool selfGenerated{false};
 };
 
 struct IDState
index 7e5ecb8df6eb6a4d72b016e2d57e9ef0585659e9..e0839bcd135c28daff9937c2c2050f5627aad8d9 100644 (file)
@@ -22,6 +22,7 @@
 #include "config.h"
 #include "threadname.hh"
 #include "dnsdist.hh"
+#include "dnsdist-async.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-lua.hh"
 #include "dnsdist-lua-ffi.hh"
@@ -492,19 +493,24 @@ public:
 
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
   {
-    auto lock = g_lua.lock();
     try {
-      auto ret = d_func(dq);
-      if (ruleresult) {
-        if (boost::optional<std::string> rule = std::get<1>(ret)) {
-          *ruleresult = *rule;
-        }
-        else {
-          // default to empty string
-          ruleresult->clear();
+      DNSAction::Action result;
+      {
+        auto lock = g_lua.lock();
+        auto ret = d_func(dq);
+        if (ruleresult) {
+          if (boost::optional<std::string> rule = std::get<1>(ret)) {
+            *ruleresult = *rule;
+          }
+          else {
+            // default to empty string
+            ruleresult->clear();
+          }
         }
+        result = static_cast<Action>(std::get<0>(ret));
       }
-      return static_cast<Action>(std::get<0>(ret));
+      dnsdist::handleQueuedAsynchronousEvents();
+      return result;
     } catch (const std::exception &e) {
       warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what());
     } catch (...) {
@@ -529,19 +535,24 @@ public:
   {}
   DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
   {
-    auto lock = g_lua.lock();
     try {
-      auto ret = d_func(dr);
-      if (ruleresult) {
-        if (boost::optional<std::string> rule = std::get<1>(ret)) {
-          *ruleresult = *rule;
-        }
-        else {
-          // default to empty string
-          ruleresult->clear();
+      DNSResponseAction::Action result;
+      {
+        auto lock = g_lua.lock();
+        auto ret = d_func(dr);
+        if (ruleresult) {
+          if (boost::optional<std::string> rule = std::get<1>(ret)) {
+            *ruleresult = *rule;
+          }
+          else {
+            // default to empty string
+            ruleresult->clear();
+          }
         }
+        result = static_cast<Action>(std::get<0>(ret));
       }
-      return static_cast<Action>(std::get<0>(ret));
+      dnsdist::handleQueuedAsynchronousEvents();
+      return result;
     } catch (const std::exception &e) {
       warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what());
     } catch (...) {
@@ -571,18 +582,23 @@ public:
   {
     dnsdist_ffi_dnsquestion_t dqffi(dq);
     try {
-      auto lock = g_lua.lock();
-      auto ret = d_func(&dqffi);
-      if (ruleresult) {
-        if (dqffi.result) {
-          *ruleresult = *dqffi.result;
-        }
-        else {
-          // default to empty string
-          ruleresult->clear();
+      DNSAction::Action result;
+      {
+        auto lock = g_lua.lock();
+        auto ret = d_func(&dqffi);
+        if (ruleresult) {
+          if (dqffi.result) {
+            *ruleresult = *dqffi.result;
+          }
+          else {
+            // default to empty string
+            ruleresult->clear();
+          }
         }
+        result = static_cast<DNSAction::Action>(ret);
       }
-      return static_cast<DNSAction::Action>(ret);
+      dnsdist::handleQueuedAsynchronousEvents();
+      return result;
     } catch (const std::exception &e) {
       warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what());
     } catch (...) {
@@ -636,6 +652,7 @@ public:
           ruleresult->clear();
         }
       }
+      dnsdist::handleQueuedAsynchronousEvents();
       return static_cast<DNSAction::Action>(ret);
     }
     catch (const std::exception &e) {
@@ -681,18 +698,23 @@ public:
   {
     dnsdist_ffi_dnsresponse_t drffi(dr);
     try {
-      auto lock = g_lua.lock();
-      auto ret = d_func(&drffi);
-      if (ruleresult) {
-        if (drffi.result) {
-          *ruleresult = *drffi.result;
-        }
-        else {
-          // default to empty string
-          ruleresult->clear();
+      DNSResponseAction::Action result;
+      {
+        auto lock = g_lua.lock();
+        auto ret = d_func(&drffi);
+        if (ruleresult) {
+          if (drffi.result) {
+            *ruleresult = *drffi.result;
+          }
+          else {
+            // default to empty string
+            ruleresult->clear();
+          }
         }
+        result = static_cast<DNSResponseAction::Action>(ret);
       }
-      return static_cast<DNSResponseAction::Action>(ret);
+      dnsdist::handleQueuedAsynchronousEvents();
+      return result;
     } catch (const std::exception &e) {
       warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what());
     } catch (...) {
@@ -746,6 +768,7 @@ public:
           ruleresult->clear();
         }
       }
+      dnsdist::handleQueuedAsynchronousEvents();
       return static_cast<DNSResponseAction::Action>(ret);
     }
     catch (const std::exception &e) {
index fbb04c2d82dda1d91e82b109b2053e29a704036b..2677168f4420abadf41fc2f3e6c4e3e402cf8c88 100644 (file)
@@ -246,8 +246,8 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
 
   --state->d_currentQueriesCount;
 
-  if (currentResponse.d_selfGenerated == false && currentResponse.d_connection && currentResponse.d_connection->getDS()) {
-    const auto& ds = currentResponse.d_connection->getDS();
+  const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
+  if (currentResponse.d_idstate.selfGenerated == false && ds) {
     const auto& ids = currentResponse.d_idstate;
     double udiff = ids.queryRealTime.udiff();
     vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff);
@@ -498,7 +498,7 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
 
   std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
 
-  if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
+  if (!response.isAsync() && response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
     // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway
     if (!response.d_connection->willBeReusable(true)) {
       // if it can't be reused even by us, well
@@ -527,32 +527,40 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
     return;
   }
 
-  try {
-    auto& ids = response.d_idstate;
-    unsigned int qnameWireLength;
-    if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) {
-      state->terminateClientConnection();
-      return;
-    }
+  if (!response.isAsync()) {
+    try {
+      auto& ids = response.d_idstate;
+      unsigned int qnameWireLength;
+      if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) {
+        state->terminateClientConnection();
+        return;
+      }
 
-    if (response.d_connection->getDS()) {
-      ++response.d_connection->getDS()->responses;
-    }
+      if (response.d_connection->getDS()) {
+        ++response.d_connection->getDS()->responses;
+      }
 
-    DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS());
+      DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS());
+      dr.d_incomingTCPState = state;
 
-    memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
+      memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
 
-    if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) {
+      if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) {
+        state->terminateClientConnection();
+        return;
+      }
+
+      if (dr.isAsynchronous()) {
+        /* we are done for now */
+        return;
+      }
+    }
+    catch (const std::exception& e) {
+      vinfolog("Unexpected exception while handling response from backend: %s", e.what());
       state->terminateClientConnection();
       return;
     }
   }
-  catch (const std::exception& e) {
-    vinfolog("Unexpected exception while handling response from backend: %s", e.what());
-    state->terminateClientConnection();
-    return;
-  }
 
   ++g_stats.responses;
   ++state->d_ci.cs->responses;
@@ -574,7 +582,7 @@ struct TCPCrossProtocolResponse
 class TCPCrossProtocolQuery : public CrossProtocolQuery
 {
 public:
-  TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
+  TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
   {
     proxyProtocolPayloadSize = 0;
   }
@@ -588,10 +596,37 @@ public:
     return d_sender;
   }
 
+  DNSQuestion getDQ() override
+  {
+    auto& ids = query.d_idstate;
+    DNSQuestion dq(ids, query.d_buffer);
+    dq.d_incomingTCPState = d_sender;
+    return dq;
+  }
+
+  DNSResponse getDR() override
+  {
+    auto& ids = query.d_idstate;
+    DNSResponse dr(ids, query.d_buffer, downstream);
+    dr.d_incomingTCPState = d_sender;
+    return dr;
+  }
+
 private:
   std::shared_ptr<IncomingTCPConnectionState> d_sender;
 };
 
+std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq)
+{
+  auto state = dq.getIncomingTCPState();
+  if (!state) {
+    throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state");
+  }
+
+  dq.ids.origID = dq.getHeader()->id;
+  return std::make_unique<TCPCrossProtocolQuery>(std::move(dq.getMutableData()), std::move(dq.ids), nullptr, std::move(state));
+}
+
 void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
 {
   if (d_threadData.crossProtocolResponsesPipe == -1) {
@@ -674,7 +709,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
       TCPResponse response;
       dh->rcode = RCode::NotImp;
       dh->qr = true;
-      response.d_selfGenerated = true;
+      response.d_idstate.selfGenerated = true;
       response.d_buffer = std::move(state->d_buffer);
       state->d_state = IncomingTCPConnectionState::State::idle;
       ++state->d_currentQueriesCount;
@@ -695,8 +730,9 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   DNSQuestion dq(ids, state->d_buffer);
   const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
   ids.origFlags = *flags;
-
+  dq.d_incomingTCPState = state;
   dq.sni = state->d_handler.getServerNameIndication();
+
   if (state->d_proxyProtocolValues) {
     /* we need to copy them, because the next queries received on that connection will
        need to get the _unaltered_ values */
@@ -708,12 +744,17 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   }
 
   std::shared_ptr<DownstreamState> ds;
-  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, ds);
+  auto result = processQuery(dq, state->d_threadData.holders, ds);
 
   if (result == ProcessQueryResult::Drop) {
     state->terminateClientConnection();
     return;
   }
+  else if (result == ProcessQueryResult::Asynchronous) {
+    /* we are done for now */
+    ++state->d_currentQueriesCount;
+    return;
+  }
 
   // the buffer might have been invalidated by now
   const dnsheader* dh = dq.getHeader();
@@ -722,6 +763,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
     memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
     response.d_idstate = std::move(ids);
     response.d_idstate.origID = dh->id;
+    response.d_idstate.selfGenerated = true;
     response.d_idstate.cs = state->d_ci.cs;
     response.d_buffer = std::move(state->d_buffer);
 
@@ -1399,6 +1441,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
     }
 
     if (cs.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs.d_tcpConcurrentConnectionsLimit) {
+      vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort());
       return;
     }
 
index a1a6061eae23d099463408928836636a63f7439c..ccf32fc8aae07ff5d1c7789666c7f11be570306e 100644 (file)
@@ -48,6 +48,7 @@
 #endif
 
 #include "dnsdist.hh"
+#include "dnsdist-async.hh"
 #include "dnsdist-cache.hh"
 #include "dnsdist-carbon.hh"
 #include "dnsdist-console.hh"
@@ -513,18 +514,14 @@ static bool applyRulesToResponse(const std::vector<DNSDistResponseRuleAction>& r
   return true;
 }
 
-bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& insertedRespRuleActions, DNSResponse& dr, bool muted)
+bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted)
 {
-  if (!applyRulesToResponse(respRuleActions, dr)) {
-    return false;
-  }
-
   bool zeroScope = false;
   if (!fixUpResponse(response, dr.ids.qname, dr.ids.origFlags, dr.ids.ednsAdded, dr.ids.ecsAdded, dr.ids.useZeroScope ? &zeroScope : nullptr)) {
     return false;
   }
 
-  if (dr.ids.packetCache && !dr.ids.skipCache && response.size() <= s_maxPacketCacheEntrySize) {
+  if (dr.ids.packetCache && !dr.ids.selfGenerated && !dr.ids.skipCache && response.size() <= s_maxPacketCacheEntrySize) {
     if (!dr.ids.useZeroScope) {
       /* if the query was not suitable for zero-scope, for
          example because it had an existing ECS entry so the hash is
@@ -547,7 +544,7 @@ bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRu
 
     dr.ids.packetCache->insert(cacheKey, zeroScope ? boost::none : dr.ids.subnet, dr.ids.cacheFlags, dr.ids.dnssecOK, dr.ids.qname, dr.ids.qtype, dr.ids.qclass, response, dr.ids.forwardedOverUDP, dr.getHeader()->rcode, dr.ids.tempFailureTTL);
 
-    if (!applyRulesToResponse(insertedRespRuleActions, dr)) {
+    if (!applyRulesToResponse(cacheInsertedRespRuleActions, dr)) {
       return false;
     }
   }
@@ -569,6 +566,19 @@ bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRu
   return true;
 }
 
+bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted)
+{
+  if (!applyRulesToResponse(respRuleActions, dr)) {
+    return false;
+  }
+
+  if (dr.isAsynchronous()) {
+    return true;
+  }
+
+  return processResponseAfterRules(response, cacheInsertedRespRuleActions, dr, muted);
+}
+
 static size_t getInitialUDPPacketBufferSize()
 {
   static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "The incoming buffer size should not be larger than s_initialUDPPacketBufferSize");
@@ -593,7 +603,7 @@ static size_t getMaximumIncomingPacketSize(const ClientState& cs)
   return s_udpIncomingBufferSize + g_proxyProtocolMaximumSize;
 }
 
-static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
 #ifndef DISABLE_DELAY_PIPE
   if (delayMsec && g_delay) {
@@ -640,7 +650,7 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff,
   doLatencyStats(incomingProtocol, udiff);
 }
 
-static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated)
+static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool isAsync, bool selfGenerated)
 {
   DNSResponse dr(ids, response, ds);
 
@@ -658,8 +668,14 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   dnsheader cleartextDH;
   memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
-  if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
-    return;
+  if (!isAsync) {
+    if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
+      return;
+    }
+
+    if (dr.isAsynchronous()) {
+      return;
+    }
   }
 
   ++g_stats.responses;
@@ -757,7 +773,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           continue;
         }
 
-        handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false);
+        handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, false);
       }
     }
     catch (const std::exception& e) {
@@ -830,6 +846,10 @@ static void spoofPacketFromString(DNSQuestion& dq, const string& spoofContent)
 
 bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop)
 {
+  if (dq.isAsynchronous()) {
+    return false;
+  }
+
   switch(action) {
   case DNSAction::Action::Allow:
     return true;
@@ -1215,10 +1235,12 @@ static void queueResponse(const ClientState& cs, const PacketBuffer& response, c
 #endif /* DISABLE_RECVMMSG */
 
 /* self-generated responses or cache hits */
-static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit)
+static bool prepareOutgoingResponse(LocalHolders& holders, const ClientState& cs, DNSQuestion& dq, bool cacheHit)
 {
   std::shared_ptr<DownstreamState> ds{nullptr};
   DNSResponse dr(dq.ids, dq.getMutableData(), ds);
+  dr.d_incomingTCPState = dq.d_incomingTCPState;
+  dr.ids.selfGenerated = true;
 
   if (!applyRulesToResponse(cacheHit ? *holders.cacheHitRespRuleactions : *holders.selfAnsweredRespRuleactions, dr)) {
     return false;
@@ -1230,6 +1252,14 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ
     ac(&dr, &result);
   }
 
+  if (cacheHit) {
+    ++g_stats.cacheHits;
+  }
+
+  if (dr.isAsynchronous()) {
+    return false;
+  }
+
 #ifdef HAVE_DNSCRYPT
   if (!cs.muted) {
     if (!encryptResponse(dq.getMutableData(), dq.getMaximumSize(), dq.overTCP(), dq.ids.dnsCryptQuery)) {
@@ -1238,28 +1268,14 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ
   }
 #endif /* HAVE_DNSCRYPT */
 
-  if (cacheHit) {
-    ++g_stats.cacheHits;
-  }
-
   return true;
 }
 
-ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
+ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
 {
   const uint16_t queryId = ntohs(dq.getHeader()->id);
 
   try {
-    /* we need an accurate ("real") value for the response and
-       to store into the IDS, but not for insertion into the
-       rings for example */
-    struct timespec now;
-    gettime(&now);
-
-    if (!applyRulesToQuery(holders, dq, now)) {
-      return ProcessQueryResult::Drop;
-    }
-
     if (dq.getHeader()->qr) { // something turned it into a response
       fixUpQueryTurnedResponse(dq, dq.ids.origFlags);
 
@@ -1329,7 +1345,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
 
         vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size());
 
-        if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+        if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, true)) {
           return ProcessQueryResult::Drop;
         }
 
@@ -1342,7 +1358,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
       else if (dq.ids.protocol == dnsdist::Protocol::DoH && !forwardedOverUDP) {
         /* do a second-lookup for UDP responses, but we do not want TC=1 answers */
         if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKeyUDP, dq.ids.subnet, dq.ids.dnssecOK, true, allowExpired, false, false, false)) {
-          if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+          if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, true)) {
             return ProcessQueryResult::Drop;
           }
 
@@ -1369,7 +1385,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
 
         fixUpQueryTurnedResponse(dq, dq.ids.origFlags);
 
-        if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+        if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, false)) {
           return ProcessQueryResult::Drop;
         }
         ++g_stats.responses;
@@ -1394,7 +1410,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
     return ProcessQueryResult::PassToBackend;
   }
   catch (const std::exception& e){
-    vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what());
+    vinfolog("Got an error while parsing a %s query (after applying rules)  from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what());
   }
   return ProcessQueryResult::Drop;
 }
@@ -1402,7 +1418,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
 class UDPTCPCrossQuerySender : public TCPQuerySender
 {
 public:
-  UDPTCPCrossQuerySender(const ClientState& cs, const std::shared_ptr<DownstreamState>& ds): d_cs(cs), d_ds(ds)
+  UDPTCPCrossQuerySender()
   {
   }
 
@@ -1415,14 +1431,9 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() const override
-  {
-    return &d_cs;
-  }
-
   void handleResponse(const struct timeval& now, TCPResponse&& response) override
   {
-    if (!d_ds && !response.d_selfGenerated) {
+    if (!response.d_ds && !response.d_idstate.selfGenerated) {
       throw std::runtime_error("Passing a cross-protocol answer originated from UDP without a valid downstream");
     }
 
@@ -1431,7 +1442,7 @@ public:
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
 
-    handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated);
+    handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, response.d_ds, response.isAsync(), response.d_idstate.selfGenerated);
   }
 
   void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
@@ -1443,23 +1454,23 @@ public:
   {
     // nothing to do
   }
-private:
-  const ClientState& d_cs;
-  const std::shared_ptr<DownstreamState> d_ds{nullptr};
 };
 
 class UDPCrossProtocolQuery : public CrossProtocolQuery
 {
 public:
-  UDPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState>& ds)
+  UDPCrossProtocolQuery(PacketBuffer&& buffer_, InternalQueryState&& ids_, std::shared_ptr<DownstreamState> ds): CrossProtocolQuery(InternalQuery(std::move(buffer_), std::move(ids_)), ds)
   {
-    uint16_t z = 0;
-    getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(buffer.data()), buffer.size(), &ids.udpPayloadSize, &z);
-    if (ids.udpPayloadSize < 512) {
-      ids.udpPayloadSize = 512;
+    auto& ids = query.d_idstate;
+    const auto& buffer = query.d_buffer;
+
+    if (ids.udpPayloadSize == 0) {
+      uint16_t z = 0;
+      getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(buffer.data()), buffer.size(), &ids.udpPayloadSize, &z);
+      if (ids.udpPayloadSize < 512) {
+        ids.udpPayloadSize = 512;
+      }
     }
-    query = InternalQuery(std::move(buffer), std::move(ids));
-    downstream = ds;
   }
 
   ~UDPCrossProtocolQuery()
@@ -1468,11 +1479,48 @@ public:
 
   std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
   {
-    auto sender = std::make_shared<UDPTCPCrossQuerySender>(*query.d_idstate.cs, downstream);
-    return sender;
+    return s_sender;
   }
+private:
+  static std::shared_ptr<UDPTCPCrossQuerySender> s_sender;
 };
 
+std::shared_ptr<UDPTCPCrossQuerySender> UDPCrossProtocolQuery::s_sender = std::make_shared<UDPTCPCrossQuerySender>();
+
+std::unique_ptr<CrossProtocolQuery> getUDPCrossProtocolQueryFromDQ(DNSQuestion& dq);
+std::unique_ptr<CrossProtocolQuery> getUDPCrossProtocolQueryFromDQ(DNSQuestion& dq)
+{
+  dq.ids.origID = dq.getHeader()->id;
+  return std::make_unique<UDPCrossProtocolQuery>(std::move(dq.getMutableData()), std::move(dq.ids), nullptr);
+}
+
+ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
+{
+  const uint16_t queryId = ntohs(dq.getHeader()->id);
+
+  try {
+    /* we need an accurate ("real") value for the response and
+       to store into the IDS, but not for insertion into the
+       rings for example */
+    struct timespec now;
+    gettime(&now);
+
+    if (!applyRulesToQuery(holders, dq, now)) {
+      return ProcessQueryResult::Drop;
+    }
+
+    if (dq.isAsynchronous()) {
+      return ProcessQueryResult::Asynchronous;
+    }
+
+    return processQueryAfterRules(dq, holders, selectedBackend);
+  }
+  catch (const std::exception& e){
+    vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what());
+  }
+  return ProcessQueryResult::Drop;
+}
+
 bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
 {
   bool doh = dq.ids.du != nullptr;
@@ -1601,9 +1649,9 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     std::shared_ptr<DownstreamState> ss{nullptr};
-    auto result = processQuery(dq, cs, holders, ss);
+    auto result = processQuery(dq, holders, ss);
 
-    if (result == ProcessQueryResult::Drop) {
+    if (result == ProcessQueryResult::Drop || result == ProcessQueryResult::Asynchronous) {
       return;
     }
 
@@ -1622,7 +1670,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */
       sendUDPResponse(cs.udpFD, query, dq.ids.delayMsec, dest, remote);
 
-      handleResponseSent(ids, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP);
+      handleResponseSent(dq.ids.qname, dq.ids.qtype, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP, dnsdist::Protocol::DoUDP);
       return;
     }
 
@@ -2586,6 +2634,8 @@ int main(int argc, char** argv)
 #endif
     }
 
+    dnsdist::g_asyncHolder = std::make_unique<dnsdist::AsynchronousHolder>();
+
     auto todo = setupLua(*(g_lua.lock()), false, false, g_cmdLine.config);
 
     auto localPools = g_pools.getCopy();
index a425018e6ecd0af4ddf9ed83c69460b5cbf94c86..0741ca266ea901d5acd710ef10a84d36f83aa0d4 100644 (file)
@@ -61,6 +61,10 @@ extern bool g_ECSOverride;
 
 using QTag = std::unordered_map<string, string>;
 
+class IncomingTCPConnectionState;
+
+struct ClientState;
+
 struct DNSQuestion
 {
   DNSQuestion(InternalQueryState& ids_, PacketBuffer& data_):
@@ -69,6 +73,7 @@ struct DNSQuestion
   DNSQuestion(const DNSQuestion&) = delete;
   DNSQuestion& operator=(const DNSQuestion&) = delete;
   DNSQuestion(DNSQuestion&&) = default;
+  virtual ~DNSQuestion() = default;
 
   std::string getTrailingData() const;
   bool setTrailingData(const std::string&);
@@ -139,6 +144,21 @@ struct DNSQuestion
     return ids.queryRealTime.d_start;
   }
 
+  bool isAsynchronous() const
+  {
+    return asynchronous;
+  }
+
+  std::shared_ptr<IncomingTCPConnectionState> getIncomingTCPState() const
+  {
+    return d_incomingTCPState;
+  }
+
+  ClientState* getFrontend() const
+  {
+    return ids.cs;
+  }
+
 protected:
   PacketBuffer& data;
 
@@ -147,14 +167,18 @@ public:
   std::unique_ptr<Netmask> ecs{nullptr};
   std::string sni; /* Server Name Indication, if any (DoT or DoH) */
   mutable std::unique_ptr<EDNSOptionViewMap> ednsOptions; /* this needs to be mutable because it is parsed just in time, when DNSQuestion is read-only */
+  std::shared_ptr<IncomingTCPConnectionState> d_incomingTCPState{nullptr};
   std::unique_ptr<std::vector<ProxyProtocolValue>> proxyProtocolValues{nullptr};
   uint16_t ecsPrefixLength;
   uint8_t ednsRCode{0};
   bool ecsOverride;
   bool useECS{true};
   bool addXPF{true};
+  bool asynchronous{false};
 };
 
+struct DownstreamState;
+
 struct DNSResponse : DNSQuestion
 {
   DNSResponse(InternalQueryState& ids_, PacketBuffer& data_, const std::shared_ptr<DownstreamState>& downstream):
@@ -1183,8 +1207,6 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n
 void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength);
-bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& insertedRespRuleActions, DNSResponse& dr, bool muted);
-bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop);
 
 bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs);
 
@@ -1203,11 +1225,16 @@ extern std::set<std::string> g_capabilitiesToRetain;
 static const uint16_t s_udpIncomingBufferSize{1500}; // don't accept UDP queries larger than this value
 static const size_t s_maxPacketCacheEntrySize{4096}; // don't cache responses larger than this value
 
-enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend };
-ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
+enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend, Asynchronous };
+ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
+ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
+bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& insertedRespRuleActions, DNSResponse& dr, bool muted);
+bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop);
+bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted);
 
 bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest);
 
 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
+bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote);
 void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol);
 void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol);
index ef7ac181a693357d7b8f24b686a4baadb2febc16..7310e4050de0ad7023527064f0cf6930d7b68e46 100644 (file)
@@ -134,6 +134,7 @@ dnsdist_SOURCES = \
        dns.cc dns.hh \
        dns_random.hh \
        dnscrypt.cc dnscrypt.hh \
+       dnsdist-async.cc dnsdist-async.hh \
        dnsdist-backend.cc \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-carbon.cc dnsdist-carbon.hh \
@@ -147,6 +148,7 @@ dnsdist_SOURCES = \
        dnsdist-ecs.cc dnsdist-ecs.hh \
        dnsdist-healthchecks.cc dnsdist-healthchecks.hh \
        dnsdist-idstate.hh \
+       dnsdist-internal-queries.cc dnsdist-internal-queries.hh \
        dnsdist-kvs.hh dnsdist-kvs.cc \
        dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \
        dnsdist-lua-actions.cc \
@@ -242,6 +244,7 @@ testrunner_SOURCES = \
        credentials.cc credentials.hh \
        dns.cc dns.hh \
        dnscrypt.cc dnscrypt.hh \
+       dnsdist-async.cc dnsdist-async.hh \
        dnsdist-backend.cc \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-dnsparser.cc dnsdist-dnsparser.hh \
@@ -304,6 +307,7 @@ testrunner_SOURCES = \
        test-dnsdist-connections-cache.cc \
        test-dnsdist-dnsparser.cc \
        test-dnsdist_cc.cc \
+       test-dnsdistasync.cc \
        test-dnsdistbackend_cc.cc \
        test-dnsdistdynblocks_hh.cc \
        test-dnsdistkvs_cc.cc \
diff --git a/pdns/dnsdistdist/dnsdist-async.cc b/pdns/dnsdistdist/dnsdist-async.cc
new file mode 100644 (file)
index 0000000..fc91744
--- /dev/null
@@ -0,0 +1,428 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "dnsdist-async.hh"
+#include "dnsdist-internal-queries.hh"
+#include "dolog.hh"
+#include "threadname.hh"
+
+namespace dnsdist
+{
+
+AsynchronousHolder::AsynchronousHolder(bool failOpen)
+{
+  d_data = std::make_shared<Data>();
+  d_data->d_failOpen = failOpen;
+
+  int fds[2] = {-1, -1};
+  if (pipe(fds) < 0) {
+    throw std::runtime_error("Error creating the AsynchronousHolder pipe: " + stringerror());
+  }
+
+  for (size_t idx = 0; idx < (sizeof(fds) / sizeof(*fds)); idx++) {
+    if (!setNonBlocking(fds[idx])) {
+      int err = errno;
+      close(fds[0]);
+      close(fds[1]);
+      throw std::runtime_error("Error setting the AsynchronousHolder pipe non-blocking: " + stringerror(err));
+    }
+  }
+
+  d_data->d_notifyPipe = FDWrapper(fds[1]);
+  d_data->d_watchPipe = FDWrapper(fds[0]);
+
+  std::thread main([data = this->d_data] { mainThread(data); });
+  main.detach();
+}
+
+AsynchronousHolder::~AsynchronousHolder()
+{
+  try {
+    stop();
+  }
+  catch (...) {
+  }
+}
+
+bool AsynchronousHolder::notify() const
+{
+  const char data = 0;
+  bool failed = false;
+  do {
+    auto written = write(d_data->d_notifyPipe.getHandle(), &data, sizeof(data));
+    if (written == 0) {
+      break;
+    }
+    if (written > 0 && static_cast<size_t>(written) == sizeof(data)) {
+      return true;
+    }
+    if (errno != EINTR) {
+      failed = true;
+    }
+  } while (!failed);
+
+  return false;
+}
+
+bool AsynchronousHolder::wait(const AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs)
+{
+  readyFDs.clear();
+  mplexer.getAvailableFDs(readyFDs, atMostMs);
+  if (readyFDs.size() == 0) {
+    /* timeout */
+    return true;
+  }
+
+  while (true) {
+    /* we might have been notified several times, let's read
+       as much as possible before returning */
+    char dummy = 0;
+    auto got = read(data.d_watchPipe.getHandle(), &dummy, sizeof(dummy));
+    if (got == 0) {
+      break;
+    }
+    if (got > 0 && static_cast<size_t>(got) != sizeof(dummy)) {
+      continue;
+    }
+    if (got == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
+      break;
+    }
+  }
+
+  return false;
+}
+
+void AsynchronousHolder::stop()
+{
+  {
+    auto content = d_data->d_content.lock();
+    d_data->d_done = true;
+  }
+
+  notify();
+}
+
+void AsynchronousHolder::mainThread(std::shared_ptr<Data> data)
+{
+  setThreadName("dnsdist/async");
+  struct timeval now;
+  std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>> expiredEvents;
+
+  auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(1));
+  mplexer->addReadFD(data->d_watchPipe.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
+  std::vector<int> readyFDs;
+
+  while (true) {
+    bool shouldWait = true;
+    int timeout = -1;
+    {
+      auto content = data->d_content.lock();
+      if (data->d_done) {
+        return;
+      }
+
+      if (!content->empty()) {
+        gettimeofday(&now, nullptr);
+        struct timeval next = getNextTTD(*content);
+        if (next <= now) {
+          pickupExpired(*content, now, expiredEvents);
+          shouldWait = false;
+        }
+        else {
+          auto remainingUsec = uSec(next - now);
+          timeout = std::round(remainingUsec / 1000.0);
+          if (timeout == 0 && remainingUsec > 0) {
+            /* if we have less than 1 ms, let's wait at least 1 ms */
+            timeout = 1;
+          }
+        }
+      }
+    }
+
+    if (shouldWait) {
+      auto timedOut = wait(*data, *mplexer, readyFDs, timeout);
+      if (timedOut) {
+        auto content = data->d_content.lock();
+        gettimeofday(&now, nullptr);
+        pickupExpired(*content, now, expiredEvents);
+      }
+    }
+
+    while (!expiredEvents.empty()) {
+      auto [queryID, query] = std::move(expiredEvents.front());
+      expiredEvents.pop_front();
+      if (!data->d_failOpen) {
+        vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec);
+        auto sender = query->getTCPQuerySender();
+        if (sender) {
+          sender->notifyIOError(std::move(query->query.d_idstate), now);
+        }
+      }
+      else {
+        vinfolog("Asynchronous query %d has expired at %d.%d, resuming", queryID, now.tv_sec, now.tv_usec);
+        resumeQuery(std::move(query));
+      }
+    }
+  }
+}
+
+void AsynchronousHolder::push(uint16_t asyncID, uint16_t queryID, const struct timeval& ttd, std::unique_ptr<CrossProtocolQuery>&& query)
+{
+  bool needNotify = false;
+  {
+    auto content = d_data->d_content.lock();
+    if (!content->empty()) {
+      /* the thread is already waiting on a TTD expiry in addition to notifications,
+         let's not wake it unless our TTD comes before the current one */
+      const struct timeval next = getNextTTD(*content);
+      if (ttd < next) {
+        needNotify = true;
+      }
+    }
+    else {
+      /* the thread is currently only waiting for a notify */
+      needNotify = true;
+    }
+    content->insert({std::move(query), ttd, asyncID, queryID});
+  }
+
+  if (needNotify) {
+    notify();
+  }
+}
+
+std::unique_ptr<CrossProtocolQuery> AsynchronousHolder::get(uint16_t asyncID, uint16_t queryID)
+{
+  /* no need to notify, worst case the thread wakes up for nothing because this was the next TTD */
+  auto content = d_data->d_content.lock();
+  auto it = content->find(std::tie(queryID, asyncID));
+  if (it == content->end()) {
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+    vinfolog("Asynchronous object %d not found at %d.%d", queryID, now.tv_sec, now.tv_usec);
+    return nullptr;
+  }
+
+  auto result = std::move(it->d_query);
+  content->erase(it);
+  return result;
+}
+
+void AsynchronousHolder::pickupExpired(content_t& content, const struct timeval& now, std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>>& events)
+{
+  auto& idx = content.get<TTDTag>();
+  for (auto it = idx.begin(); it != idx.end() && it->d_ttd < now;) {
+    events.emplace_back(it->d_queryID, std::move(it->d_query));
+    it = idx.erase(it);
+  }
+}
+
+struct timeval AsynchronousHolder::getNextTTD(const content_t& content)
+{
+  if (content.empty()) {
+    throw std::runtime_error("AsynchronousHolder::getNextTTD() called on an empty holder");
+  }
+
+  return content.get<TTDTag>().begin()->d_ttd;
+}
+
+bool AsynchronousHolder::empty()
+{
+  return d_data->d_content.read_only_lock()->empty();
+}
+
+static bool resumeResponse(std::unique_ptr<CrossProtocolQuery>&& response)
+{
+  try {
+    auto& ids = response->query.d_idstate;
+    DNSResponse dr = response->getDR();
+
+    LocalHolders holders;
+    auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dr, ids.cs->muted);
+    if (!result) {
+      /* easy */
+      return true;
+    }
+
+    auto sender = response->getTCPQuerySender();
+    if (sender) {
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      TCPResponse resp(std::move(response->query.d_buffer), std::move(response->query.d_idstate), nullptr, response->downstream);
+      resp.d_async = true;
+      sender->handleResponse(now, std::move(resp));
+    }
+  }
+  catch (const std::exception& e) {
+    vinfolog("Got exception while resuming cross-protocol response: %s", e.what());
+    return false;
+  }
+
+  return true;
+}
+
+static LockGuarded<std::deque<std::unique_ptr<CrossProtocolQuery>>> s_asynchronousEventsQueue;
+
+bool queueQueryResumptionEvent(std::unique_ptr<CrossProtocolQuery>&& query)
+{
+  s_asynchronousEventsQueue.lock()->push_back(std::move(query));
+  return true;
+}
+
+void handleQueuedAsynchronousEvents()
+{
+  while (true) {
+    std::unique_ptr<CrossProtocolQuery> query;
+    {
+      // we do not want to hold the lock while resuming
+      auto queue = s_asynchronousEventsQueue.lock();
+      if (queue->empty()) {
+        return;
+      }
+
+      query = std::move(queue->front());
+      queue->pop_front();
+    }
+    if (query && !resumeQuery(std::move(query))) {
+      vinfolog("Unable to resume asynchronous query event");
+    }
+  }
+}
+
+bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query)
+{
+  if (query->d_isResponse) {
+    return resumeResponse(std::move(query));
+  }
+
+  auto& ids = query->query.d_idstate;
+  DNSQuestion dq = query->getDQ();
+  LocalHolders holders;
+
+  auto result = processQueryAfterRules(dq, holders, query->downstream);
+  if (result == ProcessQueryResult::Drop) {
+    /* easy */
+    return true;
+  }
+  else if (result == ProcessQueryResult::PassToBackend) {
+    if (query->downstream == nullptr) {
+      return false;
+    }
+
+#ifdef HAVE_DNS_OVER_HTTPS
+    if (dq.ids.du != nullptr) {
+      dq.ids.du->downstream = query->downstream;
+    }
+#endif
+
+    if (query->downstream->isTCPOnly() || !(dq.getProtocol().isUDP() || dq.getProtocol() == dnsdist::Protocol::DoH)) {
+      query->downstream->passCrossProtocolQuery(std::move(query));
+      return true;
+    }
+
+    auto queryID = dq.getHeader()->id;
+    /* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery
+       which will stop existing when we return, so we need to increment the reference count
+    */
+    return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dq, query->query.d_buffer, ids.origDest);
+  }
+  else if (result == ProcessQueryResult::SendAnswer) {
+    auto sender = query->getTCPQuerySender();
+    if (!sender) {
+      return false;
+    }
+
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+
+    TCPResponse response(std::move(query->query.d_buffer), std::move(query->query.d_idstate), nullptr, query->downstream);
+    response.d_async = true;
+    response.d_idstate.selfGenerated = true;
+
+    try {
+      sender->handleResponse(now, std::move(response));
+      return true;
+    }
+    catch (const std::exception& e) {
+      vinfolog("Got exception while resuming cross-protocol self-answered query: %s", e.what());
+      return false;
+    }
+  }
+  else if (result == ProcessQueryResult::Asynchronous) {
+    /* nope */
+    errlog("processQueryAfterRules returned 'asynchronous' while trying to resume an already asynchronous query");
+    return false;
+  }
+
+  return false;
+}
+
+bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
+{
+  if (!g_asyncHolder) {
+    return false;
+  }
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+  struct timeval ttd = now;
+  ttd.tv_sec += timeoutMs / 1000;
+  ttd.tv_usec += (timeoutMs % 1000) * 1000;
+  if (ttd.tv_usec >= 1000000) {
+    ttd.tv_sec++;
+    ttd.tv_usec -= 1000000;
+  }
+
+  vinfolog("Suspending asynchronous query %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec);
+  auto query = getInternalQueryFromDQ(dq, false);
+
+  g_asyncHolder->push(asyncID, queryID, ttd, std::move(query));
+  return true;
+}
+
+bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
+{
+  if (!g_asyncHolder) {
+    return false;
+  }
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+  struct timeval ttd = now;
+  ttd.tv_sec += timeoutMs / 1000;
+  ttd.tv_usec += (timeoutMs % 1000) * 1000;
+  if (ttd.tv_usec >= 1000000) {
+    ttd.tv_sec++;
+    ttd.tv_usec -= 1000000;
+  }
+
+  vinfolog("Suspending asynchronous response %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec);
+  auto query = getInternalQueryFromDQ(dr, true);
+  query->d_isResponse = true;
+  query->downstream = dr.d_downstream;
+
+  g_asyncHolder->push(asyncID, queryID, ttd, std::move(query));
+  return true;
+}
+
+std::unique_ptr<AsynchronousHolder> g_asyncHolder;
+}
diff --git a/pdns/dnsdistdist/dnsdist-async.hh b/pdns/dnsdistdist/dnsdist-async.hh
new file mode 100644 (file)
index 0000000..5a8c090
--- /dev/null
@@ -0,0 +1,98 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <thread>
+
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/ordered_index.hpp>
+#include <boost/multi_index/key_extractors.hpp>
+
+#include "dnsdist-tcp.hh"
+
+namespace dnsdist
+{
+class AsynchronousHolder
+{
+public:
+  AsynchronousHolder(bool failOpen = true);
+  ~AsynchronousHolder();
+  void push(uint16_t asyncID, uint16_t queryID, const struct timeval& ttd, std::unique_ptr<CrossProtocolQuery>&& query);
+  std::unique_ptr<CrossProtocolQuery> get(uint16_t asyncID, uint16_t queryID);
+  bool empty();
+  void stop();
+
+private:
+  struct TTDTag
+  {
+  };
+  struct IDTag
+  {
+  };
+
+  struct Entry
+  {
+    /* not used by any of the indexes, so mutable */
+    mutable std::unique_ptr<CrossProtocolQuery> d_query;
+    struct timeval d_ttd;
+    uint16_t d_asyncID;
+    uint16_t d_queryID;
+  };
+
+  typedef multi_index_container<
+    Entry,
+    indexed_by<
+      ordered_unique<tag<IDTag>,
+                     composite_key<
+                       Entry,
+                       member<Entry, uint16_t, &Entry::d_queryID>,
+                       member<Entry, uint16_t, &Entry::d_asyncID>>>,
+      ordered_non_unique<tag<TTDTag>,
+                         member<Entry, struct timeval, &Entry::d_ttd>>>>
+    content_t;
+
+  static void pickupExpired(content_t&, const struct timeval& now, std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>>& expiredEvents);
+  static struct timeval getNextTTD(const content_t&);
+
+  struct Data
+  {
+    LockGuarded<content_t> d_content;
+    FDWrapper d_notifyPipe;
+    FDWrapper d_watchPipe;
+    bool d_failOpen{true};
+    bool d_done{false};
+  };
+  std::shared_ptr<Data> d_data{nullptr};
+
+  static void mainThread(std::shared_ptr<Data> data);
+  static bool wait(const Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs);
+  bool notify() const;
+};
+
+bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs);
+bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs);
+bool queueQueryResumptionEvent(std::unique_ptr<CrossProtocolQuery>&& query);
+bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query);
+void handleQueuedAsynchronousEvents();
+
+extern std::unique_ptr<AsynchronousHolder> g_asyncHolder;
+}
index a6328f168cd00c2cf92f95382a55fdbdb70276ad..480bfd1960a0c4eeea02977b6557a714764d323a 100644 (file)
@@ -140,11 +140,6 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() const override
-  {
-    return nullptr;
-  }
-
   void handleResponse(const struct timeval& now, TCPResponse&& response) override
   {
     d_data->d_buffer = std::move(response.d_buffer);
diff --git a/pdns/dnsdistdist/dnsdist-internal-queries.cc b/pdns/dnsdistdist/dnsdist-internal-queries.cc
new file mode 100644 (file)
index 0000000..49f95e4
--- /dev/null
@@ -0,0 +1,45 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "dnsdist-internal-queries.hh"
+#include "dnsdist-tcp.hh"
+#include "doh.hh"
+
+std::unique_ptr<CrossProtocolQuery> getUDPCrossProtocolQueryFromDQ(DNSQuestion& dq);
+
+namespace dnsdist
+{
+std::unique_ptr<CrossProtocolQuery> getInternalQueryFromDQ(DNSQuestion& dq, bool isResponse)
+{
+  auto protocol = dq.getProtocol();
+  if (protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP) {
+    return getUDPCrossProtocolQueryFromDQ(dq);
+  }
+#ifdef HAVE_DNS_OVER_HTTPS
+  else if (protocol == dnsdist::Protocol::DoH) {
+    return getDoHCrossProtocolQueryFromDQ(dq, isResponse);
+  }
+#endif
+  else {
+    return getTCPCrossProtocolQueryFromDQ(dq);
+  }
+}
+}
diff --git a/pdns/dnsdistdist/dnsdist-internal-queries.hh b/pdns/dnsdistdist/dnsdist-internal-queries.hh
new file mode 100644 (file)
index 0000000..46634aa
--- /dev/null
@@ -0,0 +1,30 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <memory>
+#include "dnsdist.hh"
+
+namespace dnsdist
+{
+std::unique_ptr<CrossProtocolQuery> getInternalQueryFromDQ(DNSQuestion& dq, bool isResponse);
+}
index 3d2c21c4fb694cbc9d966904e8aa029d6cf2935b..e66a13986ac045b61c2182ec517ed2eae5456253 100644 (file)
@@ -20,6 +20,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "dnsdist.hh"
+#include "dnsdist-async.hh"
 #include "dnsdist-lua.hh"
 #include "dnsdist-lua-network.hh"
 #include "dolog.hh"
@@ -66,8 +67,11 @@ void setupLuaBindingsNetwork(LuaContext& luaCtx, bool client)
     }
 
     return listener->addUnixListeningEndpoint(path, endpointID, [cb](dnsdist::NetworkListener::EndpointID endpoint, std::string&& dgram, const std::string& from) {
-      auto lock = g_lua.lock();
-      cb(endpoint, dgram, from);
+      {
+        auto lock = g_lua.lock();
+        cb(endpoint, dgram, from);
+      }
+      dnsdist::handleQueuedAsynchronousEvents();
     });
   });
 
index 3a509058a2a9f4038deb0e50dd6cdd39fa737f6f..741bd3aedc3c180bfe57a2193f942274b61343b7 100644 (file)
@@ -67,6 +67,7 @@ void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq,
 size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+uint16_t dnsdist_ffi_dnsquestion_get_id(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
@@ -85,11 +86,13 @@ uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquest
 bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize) __attribute__ ((visibility ("default")));
 const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label) __attribute__ ((visibility ("default")));
+size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq, const char* label, char* buffer, size_t bufferSize) __attribute__ ((visibility ("default")));
 const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, void* buffer, size_t bufferSize) __attribute__ ((visibility ("default")));
+uint64_t dnsdist_ffi_dnsquestion_get_elapsed_us(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 
 // returns the length of the resulting 'out' array. 'out' is not set if the length is 0
 size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_ednsoption_t** out) __attribute__ ((visibility ("default")));
@@ -106,6 +109,7 @@ void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq
 void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) __attribute__ ((visibility ("default")));
+void dnsdist_ffi_dnsquestion_set_tag_raw(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value, size_t valueSize) __attribute__ ((visibility ("default")));
 
 void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType) __attribute__ ((visibility ("default")));
 
@@ -148,6 +152,14 @@ void dnsdist_ffi_dnsresponse_limit_ttl(dnsdist_ffi_dnsresponse_t* dr, uint32_t m
 void dnsdist_ffi_dnsresponse_set_max_returned_ttl(dnsdist_ffi_dnsresponse_t* dr, uint32_t max) __attribute__ ((visibility ("default")));
 void dnsdist_ffi_dnsresponse_clear_records_type(dnsdist_ffi_dnsresponse_t* dr, uint16_t qtype) __attribute__ ((visibility ("default")));
 
+bool dnsdist_ffi_dnsquestion_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsresponse_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) __attribute__ ((visibility ("default")));
+
+bool dnsdist_ffi_resume_from_async(uint16_t asyncID, uint16_t queryID, const char* tag, size_t tagSize, const char* tagValue, size_t tagValueSize, bool useCache) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_drop_from_async(uint16_t asyncID, uint16_t queryID) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_set_answer_from_async(uint16_t asyncID, uint16_t queryID, const char* raw, size_t rawSize) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_set_rcode_from_async(uint16_t asyncID, uint16_t queryID, uint8_t rcode, bool clearAnswers) __attribute__ ((visibility ("default")));
+
 typedef struct dnsdist_ffi_proxy_protocol_value {
   const char* value;
   uint16_t size;
index a0f782710a069a151265976c91005b5f5d782a8f..f204b0aca09d7063bf3c1a82ff4dcc531d003b4e 100644 (file)
@@ -20,6 +20,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 
+#include "dnsdist-async.hh"
 #include "dnsdist-dnsparser.hh"
 #include "dnsdist-lua-ffi.hh"
 #include "dnsdist-mac-address.hh"
@@ -39,6 +40,14 @@ uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq)
   return dq->dq->ids.qclass;
 }
 
+uint16_t dnsdist_ffi_dnsquestion_get_id(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (dq == nullptr) {
+    return 0;
+  }
+  return ntohs(dq->dq->getHeader()->id);
+}
+
 static void dnsdist_ffi_comboaddress_to_raw(const ComboAddress& ca, const void** addr, size_t* addrSize)
 {
   if (ca.isIPv4()) {
@@ -66,7 +75,6 @@ size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq,
   if (dq == nullptr) {
     return 0;
   }
-
   auto ret = dnsdist::MacAddressesCache::get(dq->dq->ids.origRemote, reinterpret_cast<unsigned char*>(buffer), bufferSize);
   if (ret != 0) {
     return 0;
@@ -75,6 +83,15 @@ size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq,
   return 6;
 }
 
+uint64_t dnsdist_ffi_dnsquestion_get_elapsed_us(const dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (dq == nullptr) {
+    return 0;
+  }
+
+  return dq->dq->ids.queryRealTime.udiff();
+}
+
 void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits)
 {
   dq->maskedRemote = Netmask(dq->dq->ids.origRemote, bits).getMaskedNetwork();
@@ -223,7 +240,7 @@ const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq,
 {
   const char * result = nullptr;
 
-  if (dq->dq->ids.qTag != nullptr) {
+  if (dq != nullptr && dq->dq != nullptr && dq->dq->ids.qTag != nullptr) {
     const auto it = dq->dq->ids.qTag->find(label);
     if (it != dq->dq->ids.qTag->cend()) {
       result = it->second.c_str();
@@ -233,6 +250,25 @@ const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq,
   return result;
 }
 
+size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq, const char* label, char* buffer, size_t bufferSize)
+{
+  if (dq == nullptr || dq->dq == nullptr || dq->dq->ids.qTag == nullptr || label == nullptr || buffer == nullptr || bufferSize == 0) {
+    return 0;
+  }
+
+  const auto it = dq->dq->ids.qTag->find(label);
+  if (it == dq->dq->ids.qTag->cend()) {
+    return 0;
+  }
+
+  if (it->second.size() > bufferSize) {
+    return 0;
+  }
+
+  memcpy(buffer, it->second.c_str(), it->second.size());
+  return it->second.size();
+}
+
 const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq)
 {
   if (!dq->httpPath) {
@@ -380,7 +416,7 @@ size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, c
 
 size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out)
 {
-  if (dq->dq->ids.qTag == nullptr || dq->dq->ids.qTag->size() == 0) {
+  if (dq == nullptr || dq->dq == nullptr || dq->dq->ids.qTag == nullptr || dq->dq->ids.qTag->size() == 0) {
     return 0;
   }
 
@@ -470,6 +506,11 @@ void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char*
   dq->dq->setTag(label, value);
 }
 
+void dnsdist_ffi_dnsquestion_set_tag_raw(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value, size_t valueSize)
+{
+  dq->dq->setTag(label, std::string(value, valueSize));
+}
+
 size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out)
 {
   dq->trailingData = dq->dq->getTrailingData();
@@ -649,6 +690,168 @@ void dnsdist_ffi_dnsresponse_clear_records_type(dnsdist_ffi_dnsresponse_t* dr, u
   }
 }
 
+bool dnsdist_ffi_dnsquestion_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
+{
+  try {
+    dq->dq->asynchronous = true;
+    dnsdist::suspendQuery(*dq->dq, asyncID, queryID, timeoutMs);
+    return true;
+  }
+  catch (const std::exception& e) {
+    vinfolog("Error in dnsdist_ffi_dnsquestion_set_async: %s", e.what());
+  }
+  catch (...) {
+    vinfolog("Exception in dnsdist_ffi_dnsquestion_set_async");
+  }
+
+  return false;
+}
+
+bool dnsdist_ffi_dnsresponse_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
+{
+  try {
+    dq->dq->asynchronous = true;
+    auto dr = dynamic_cast<DNSResponse*>(dq->dq);
+    if (!dr) {
+      vinfolog("Passed a DNSQuestion instead of a DNSResponse to dnsdist_ffi_dnsresponse_set_async");
+      return false;
+    }
+
+    dnsdist::suspendResponse(*dr, asyncID, queryID, timeoutMs);
+    return true;
+  }
+  catch (const std::exception& e) {
+    vinfolog("Error in dnsdist_ffi_dnsresponse_set_async: %s", e.what());
+  }
+  catch (...) {
+    vinfolog("Exception in dnsdist_ffi_dnsresponse_set_async");
+  }
+  return false;
+}
+
+bool dnsdist_ffi_resume_from_async(uint16_t asyncID, uint16_t queryID, const char* tag, size_t tagSize, const char* tagValue, size_t tagValueSize, bool useCache)
+{
+  if (!dnsdist::g_asyncHolder) {
+    vinfolog("Unable to resume, no asynchronous holder");
+    return false;
+  }
+
+  auto query = dnsdist::g_asyncHolder->get(asyncID, queryID);
+  if (!query) {
+    vinfolog("Unable to resume, no object found for asynchronous ID %d and query ID %d", asyncID, queryID);
+    return false;
+  }
+
+  auto& ids = query->query.d_idstate;
+  if (tag != nullptr && tagSize > 0) {
+    if (!ids.qTag) {
+      ids.qTag = std::make_unique<QTag>();
+    }
+    (*ids.qTag)[std::string(tag, tagSize)] = std::string(tagValue, tagValueSize);
+  }
+
+  ids.skipCache = !useCache;
+
+  return dnsdist::queueQueryResumptionEvent(std::move(query));
+}
+
+bool dnsdist_ffi_set_rcode_from_async(uint16_t asyncID, uint16_t queryID, uint8_t rcode, bool clearAnswers)
+{
+  if (!dnsdist::g_asyncHolder) {
+    return false;
+  }
+
+  auto query = dnsdist::g_asyncHolder->get(asyncID, queryID);
+  if (!query) {
+    vinfolog("Unable to resume with a custom response code, no object found for asynchronous ID %d and query ID %d", asyncID, queryID);
+    return false;
+  }
+
+  const auto qnameLength = query->query.d_idstate.qname.wirelength();
+  auto& buffer = query->query.d_buffer;
+  if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
+    return false;
+  }
+
+  EDNS0Record edns0;
+  bool hadEDNS = false;
+  if (clearAnswers) {
+    hadEDNS = getEDNS0Record(buffer, edns0);
+  }
+
+  auto dh = reinterpret_cast<dnsheader*>(buffer.data());
+  dh->rcode = rcode;
+  dh->ad = false;
+  dh->aa = false;
+  dh->ra = dh->rd;
+  dh->qr = true;
+
+  if (clearAnswers) {
+    dh->ancount = 0;
+    dh->nscount = 0;
+    dh->arcount = 0;
+    buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
+    if (hadEDNS) {
+      if (!addEDNS(buffer, query->query.d_idstate.protocol.isUDP() ? 4096 : std::numeric_limits<uint16_t>::max(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) {
+        return false;
+      }
+    }
+  }
+
+  query->query.d_idstate.skipCache = true;
+
+  return dnsdist::queueQueryResumptionEvent(std::move(query));
+}
+
+bool dnsdist_ffi_drop_from_async(uint16_t asyncID, uint16_t queryID)
+{
+  if (!dnsdist::g_asyncHolder) {
+    return false;
+  }
+
+  auto query = dnsdist::g_asyncHolder->get(asyncID, queryID);
+  if (!query) {
+    vinfolog("Unable to drop, no object found for asynchronous ID %d and query ID %d", asyncID, queryID);
+    return false;
+  }
+
+  auto sender = query->getTCPQuerySender();
+  if (!sender) {
+    return false;
+  }
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+  sender->notifyIOError(std::move(query->query.d_idstate), now);
+
+  return true;
+}
+
+bool dnsdist_ffi_set_answer_from_async(uint16_t asyncID, uint16_t queryID, const char* raw, size_t rawSize)
+{
+  if (rawSize < sizeof(dnsheader)) {
+    return false;
+  }
+  if (!dnsdist::g_asyncHolder) {
+    return false;
+  }
+
+  auto query = dnsdist::g_asyncHolder->get(asyncID, queryID);
+  if (!query) {
+    vinfolog("Unable to resume with a custom answer, no object found for asynchronous ID %d and query ID %d", asyncID, queryID);
+    return false;
+  }
+
+  auto oldId = reinterpret_cast<const dnsheader*>(query->query.d_buffer.data())->id;
+  query->query.d_buffer.clear();
+  query->query.d_buffer.insert(query->query.d_buffer.begin(), raw, raw + rawSize);
+  reinterpret_cast<dnsheader*>(query->query.d_buffer.data())->id = oldId;
+
+  query->query.d_idstate.skipCache = true;
+
+  return dnsdist::queueQueryResumptionEvent(std::move(query));
+}
+
 static constexpr char s_lua_ffi_code[] = R"FFICodeContent(
   local ffi = require("ffi")
   local C = ffi.C
@@ -797,6 +1000,10 @@ size_t dnsdist_ffi_packetcache_get_domain_list_by_addr(const char* poolName, con
     vinfolog("Error parsing address passed to dnsdist_ffi_packetcache_get_domain_list_by_addr: %s", e.what());
     return 0;
   }
+  catch (const PDNSException& e) {
+    vinfolog("Error parsing address passed to dnsdist_ffi_packetcache_get_domain_list_by_addr: %s", e.reason);
+    return 0;
+  }
 
   const auto localPools = g_pools.getCopy();
   auto it = localPools.find(poolName);
@@ -1037,6 +1244,10 @@ size_t dnsdist_ffi_ring_get_entries_by_addr(const char* addr, dnsdist_ffi_ring_e
     vinfolog("Unable to convert address in dnsdist_ffi_ring_get_entries_by_addr: %s", e.what());
     return 0;
   }
+  catch (const PDNSException& e) {
+    vinfolog("Unable to convert address in dnsdist_ffi_ring_get_entries_by_addr: %s", e.reason);
+    return 0;
+  }
 
   auto list = std::make_unique<dnsdist_ffi_ring_entry_list_t>();
 
index 819137a400fc50081a0a181f9be0852db4eb75bb..68928880089f86a835d9a3e83c0dc2fff2751614 100644 (file)
 
 #include "dnsdist-lua-network.hh"
 #include "dolog.hh"
+#include "threadname.hh"
 
 namespace dnsdist
 {
 NetworkListener::NetworkListener() :
-  d_mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  d_mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(10)))
 {
 }
 
@@ -131,6 +132,7 @@ void NetworkListener::runOnce(struct timeval& now, uint32_t timeout)
 
 void NetworkListener::mainThread()
 {
+  setThreadName("dnsdist/lua-net");
   struct timeval now;
 
   while (true) {
index 56884d0f31bca98bf63527e0eb77d5505e7d3840..5c745eac1c1381c4ceedf8a92bd21876bc9706ac 100644 (file)
@@ -148,7 +148,7 @@ void DoHConnectionToBackend::handleResponse(PendingRequest&& request)
       }
     }
 
-    request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this()));
+    request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this(), d_ds));
   }
   catch (const std::exception& e) {
     vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what());
index 9da42acfb7f2e17ae18663e41e9bc2c5699075fb..a6ab7002f21bef9cc8774798c6d1a800cd6e2e74 100644 (file)
@@ -38,7 +38,7 @@ ConnectionToBackend::~ConnectionToBackend()
 bool ConnectionToBackend::reconnect()
 {
   std::unique_ptr<TLSSession> tlsSession{nullptr};
-  if (d_handler) { 
+  if (d_handler) {
     DEBUGLOG("closing socket "<<d_handler->getDescriptor());
     if (d_handler->isTLS()) {
       if (d_handler->hasTLSSessionBeenResumed()) {
@@ -73,18 +73,18 @@ bool ConnectionToBackend::reconnect()
     DEBUGLOG("Opening TCP connection to backend "<<d_ds->getNameWithAddr());
     ++d_ds->tcpNewConnections;
     try {
-      auto socket = std::make_unique<Socket>(d_ds->d_config.remote.sin4.sin_family, SOCK_STREAM, 0);
-      DEBUGLOG("result of socket() is "<<socket->getHandle());
+      auto socket = Socket(d_ds->d_config.remote.sin4.sin_family, SOCK_STREAM, 0);
+      DEBUGLOG("result of socket() is "<<socket.getHandle());
 
       /* disable NAGLE, which does not play nicely with delayed ACKs.
          In theory we could be wasting up to 500 milliseconds waiting for
          the other end to acknowledge our initial packet before we could
          send the rest. */
-      setTCPNoDelay(socket->getHandle());
+      setTCPNoDelay(socket.getHandle());
 
 #ifdef SO_BINDTODEVICE
       if (!d_ds->d_config.sourceItfName.empty()) {
-        int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->d_config.sourceItfName.c_str(), d_ds->d_config.sourceItfName.length());
+        int res = setsockopt(socket.getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->d_config.sourceItfName.c_str(), d_ds->d_config.sourceItfName.length());
         if (res != 0) {
           vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror());
         }
@@ -92,19 +92,18 @@ bool ConnectionToBackend::reconnect()
 #endif
 
       if (!IsAnyAddress(d_ds->d_config.sourceAddr)) {
-        SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
+        SSetsockopt(socket.getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
 #ifdef IP_BIND_ADDRESS_NO_PORT
         if (d_ds->d_config.ipBindAddrNoPort) {
-          SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+          SSetsockopt(socket.getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
         }
 #endif
-        socket->bind(d_ds->d_config.sourceAddr, false);
+        socket.bind(d_ds->d_config.sourceAddr, false);
       }
-
-      socket->setNonBlocking();
+      socket.setNonBlocking();
 
       gettimeofday(&d_connectionStartTime, nullptr);
-      auto handler = std::make_unique<TCPIOHandler>(d_ds->d_config.d_tlsSubjectName, d_ds->d_config.d_tlsSubjectIsAddr, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec);
+      auto handler = std::make_unique<TCPIOHandler>(d_ds->d_config.d_tlsSubjectName, d_ds->d_config.d_tlsSubjectIsAddr, socket.releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec);
       if (!tlsSession && d_ds->d_tlsCtx) {
         tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec);
       }
@@ -591,15 +590,13 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
   auto pendingResponses = std::move(d_pendingResponses);
   d_pendingResponses.clear();
 
-  auto increaseCounters = [reason](std::shared_ptr<TCPQuerySender>& sender) {
+  auto increaseCounters = [reason](const ClientState* cs) {
     if (reason == FailureReason::timeout) {
-      const ClientState* cs = sender->getClientState();
       if (cs) {
         ++cs->tcpDownstreamTimeouts;
       }
     }
     else if (reason == FailureReason::gaveUp) {
-      const ClientState* cs = sender->getClientState();
       if (cs) {
         ++cs->tcpGaveUp;
       }
@@ -608,25 +605,25 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
 
   try {
     if (d_state == State::sendingQueryToBackend) {
+      increaseCounters(d_currentQuery.d_query.d_idstate.cs);
       auto sender = d_currentQuery.d_sender;
       if (sender->active()) {
-        increaseCounters(sender);
         sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now);
       }
     }
 
     for (auto& query : pendingQueries) {
+      increaseCounters(query.d_query.d_idstate.cs);
       auto sender = query.d_sender;
       if (sender->active()) {
-        increaseCounters(sender);
         sender->notifyIOError(std::move(query.d_query.d_idstate), now);
       }
     }
 
     for (auto& response : pendingResponses) {
+      increaseCounters(response.second.d_query.d_idstate.cs);
       auto sender = response.second.d_sender;
       if (sender->active()) {
-        increaseCounters(sender);
         sender->notifyIOError(std::move(response.second.d_query.d_idstate), now);
       }
     }
@@ -672,6 +669,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
     TCPResponse response;
     response.d_buffer = std::move(d_responseBuffer);
     response.d_connection = conn;
+    response.d_ds = conn->d_ds;
     /* we don't move the whole IDS because we will need for the responses to come */
     response.d_idstate.qtype = it->second.d_query.d_idstate.qtype;
     response.d_idstate.qname = it->second.d_query.d_idstate.qname;
@@ -728,7 +726,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
   if (sender->active()) {
     DEBUGLOG("passing response to client connection for "<<ids.qname);
     // make sure that we still exist after calling handleResponse()
-    sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn));
+    sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds));
   }
 
   if (!d_pendingQueries.empty()) {
index e48d52317dd4166582c9e35a51080406b27af2af..59c4df410d241882a7b6438d42c3bf4feaf811a4 100644 (file)
@@ -139,11 +139,6 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
     return d_ioState != nullptr;
   }
 
-  const ClientState* getClientState() const override
-  {
-    return d_ci.cs;
-  }
-
   std::string toString() const
   {
     ostringstream o;
index 43d615dea2c8e359879c89a17cfb6abf0ac0fd98..3d11f1a4f4975fd26e4024fb909ccad173922645 100644 (file)
@@ -72,8 +72,9 @@ struct ConnectionInfo
   int fd{-1};
 };
 
-struct InternalQuery
+class InternalQuery
 {
+public:
   InternalQuery()
   {
   }
@@ -119,15 +120,26 @@ struct TCPResponse : public TCPQuery
     memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
   }
 
-  TCPResponse(PacketBuffer&& buffer, InternalQueryState&& state, std::shared_ptr<ConnectionToBackend> conn) :
-    TCPQuery(std::move(buffer), std::move(state)), d_connection(conn)
+  TCPResponse(PacketBuffer&& buffer, InternalQueryState&& state, std::shared_ptr<ConnectionToBackend> conn, std::shared_ptr<DownstreamState> ds) :
+    TCPQuery(std::move(buffer), std::move(state)), d_connection(conn), d_ds(ds)
   {
-    memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
+    if (d_buffer.size() >= sizeof(dnsheader)) {
+      memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+    }
+    else {
+      memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
+    }
+  }
+
+  bool isAsync() const
+  {
+    return d_async;
   }
 
   std::shared_ptr<ConnectionToBackend> d_connection{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
   dnsheader d_cleartextDH;
-  bool d_selfGenerated{false};
+  bool d_async{false};
 };
 
 class TCPQuerySender
@@ -138,7 +150,6 @@ public:
   }
 
   virtual bool active() const = 0;
-  virtual const ClientState* getClientState() const = 0;
   virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0;
   virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0;
   virtual void notifyIOError(InternalQueryState&& query, const struct timeval& now) = 0;
@@ -170,11 +181,24 @@ struct CrossProtocolQuery
   }
 
   virtual std::shared_ptr<TCPQuerySender> getTCPQuerySender() = 0;
+  virtual DNSQuestion getDQ()
+  {
+    auto& ids = query.d_idstate;
+    DNSQuestion dq(ids, query.d_buffer);
+    return dq;
+  }
+
+  virtual DNSResponse getDR()
+  {
+    auto& ids = query.d_idstate;
+    DNSResponse dr(ids, query.d_buffer, downstream);
+    return dr;
+  }
 
   InternalQuery query;
   std::shared_ptr<DownstreamState> downstream{nullptr};
   size_t proxyProtocolPayloadSize{0};
-  bool isXFR{false};
+  bool d_isResponse{false};
 };
 
 class TCPClientCollection
@@ -278,3 +302,5 @@ private:
 };
 
 extern std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
+
+std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq);
index cc4c549bf1d872d40689bc0a4112c5e6f6502acb..d505e1a0c199da62fd614e56de3f5728b32164ae 100644 (file)
@@ -433,7 +433,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo
 class DoHTCPCrossQuerySender : public TCPQuerySender
 {
 public:
-  DoHTCPCrossQuerySender(const ClientState& cs): d_cs(cs)
+  DoHTCPCrossQuerySender()
   {
   }
 
@@ -442,11 +442,6 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() const override
-  {
-    return &d_cs;
-  }
-
   void handleResponse(const struct timeval& now, TCPResponse&& response) override
   {
     if (!response.d_idstate.du) {
@@ -462,32 +457,40 @@ public:
     du->ids = std::move(response.d_idstate);
     DNSResponse dr(du->ids, du->response, du->downstream);
 
-    static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
-    static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
-
     dnsheader cleartextDH;
     memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
-    dr.ids.du = std::move(du);
+    if (!response.isAsync()) {
+      static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
+      static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
 
-    if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
-      if (dr.ids.du) {
-        dr.ids.du->status_code = 503;
-        sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules");
+      dr.ids.du = std::move(du);
+
+      if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
+        if (dr.ids.du) {
+          dr.ids.du->status_code = 503;
+          sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules");
+        }
+        return;
       }
-      return;
-    }
 
-    du = std::move(dr.ids.du);
+      if (dr.isAsynchronous()) {
+        return;
+      }
 
-    double udiff = du->ids.queryRealTime.udiff();
-    vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff);
+      du = std::move(dr.ids.du);
+    }
+
+    if (!du->ids.selfGenerated) {
+      double udiff = du->ids.queryRealTime.udiff();
+      vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff);
 
-    auto backendProtocol = du->downstream->getProtocol();
-    if (backendProtocol == dnsdist::Protocol::DoUDP && du->tcp) {
-      backendProtocol = dnsdist::Protocol::DoTCP;
+      auto backendProtocol = du->downstream->getProtocol();
+      if (backendProtocol == dnsdist::Protocol::DoUDP && du->tcp) {
+        backendProtocol = dnsdist::Protocol::DoTCP;
+      }
+      handleResponseSent(du->ids, udiff, du->ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol);
     }
-    handleResponseSent(du->ids, udiff, du->ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol);
 
     ++g_stats.responses;
     if (du->ids.cs) {
@@ -517,16 +520,23 @@ public:
     du->status_code = 502;
     sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response");
   }
-protected:
-  const ClientState& d_cs;
 };
 
 class DoHCrossProtocolQuery : public CrossProtocolQuery
 {
 public:
-  DoHCrossProtocolQuery(DOHUnitUniquePtr&& du)
+  DoHCrossProtocolQuery(DOHUnitUniquePtr&& du, bool isResponse)
   {
-    query = InternalQuery(std::move(du->query), std::move(du->ids));
+    if (isResponse) {
+      /* happens when a response becomes async */
+      query = InternalQuery(std::move(du->response), std::move(du->ids));
+    }
+    else {
+      /* we need to duplicate the query here because we might need
+         the existing query later if we get a truncated answer */
+      query = InternalQuery(PacketBuffer(du->query), std::move(du->ids));
+    }
+
     /* it might have been moved when we moved du->ids */
     if (du) {
       query.d_idstate.du = std::move(du);
@@ -551,16 +561,61 @@ public:
   std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
   {
     query.d_idstate.du->downstream = downstream;
-    auto sender = std::make_shared<DoHTCPCrossQuerySender>(*query.d_idstate.cs);
-    return sender;
+    return s_sender;
+  }
+
+  DNSQuestion getDQ() override
+  {
+    auto& ids = query.d_idstate;
+    DNSQuestion dq(ids, query.d_buffer);
+    return dq;
   }
 
+  DNSResponse getDR() override
+  {
+    auto& ids = query.d_idstate;
+    DNSResponse dr(ids, query.d_buffer, downstream);
+    return dr;
+   }
+
   DOHUnitUniquePtr&& releaseDU()
   {
     return std::move(query.d_idstate.du);
   }
+
+private:
+  static std::shared_ptr<DoHTCPCrossQuerySender> s_sender;
 };
 
+std::shared_ptr<DoHTCPCrossQuerySender> DoHCrossProtocolQuery::s_sender = std::make_shared<DoHTCPCrossQuerySender>();
+
+std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse)
+{
+  if (!dq.ids.du) {
+    throw std::runtime_error("Trying to create a DoH cross protocol query without a valid DoH unit");
+  }
+
+  auto du = std::move(dq.ids.du);
+  if (&dq.ids != &du->ids) {
+   du->ids = std::move(dq.ids);
+  }
+
+  du->ids.origID = dq.getHeader()->id;
+
+  if (!isResponse) {
+    if (du->query.data() != dq.getMutableData().data()) {
+      du->query = std::move(dq.getMutableData());
+    }
+  }
+  else {
+    if (du->response.data() != dq.getMutableData().data()) {
+      du->response = std::move(dq.getMutableData());
+    }
+  }
+
+  return std::make_unique<DoHCrossProtocolQuery>(std::move(du), isResponse);
+}
+
 /*
    We are not in the main DoH thread but in the DoH 'client' thread.
 */
@@ -650,6 +705,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
       queryId = ntohs(dh->id);
     }
 
+    auto downstream = du->downstream;
     du->ids.qname = DNSName(reinterpret_cast<const char*>(du->query.data()), du->query.size(), sizeof(dnsheader), false, &du->ids.qtype, &du->ids.qclass);
     DNSQuestion dq(du->ids, du->query);
     const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
@@ -657,22 +713,24 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
     du->ids.cs = &cs;
     dq.sni = std::move(du->sni);
 
-    auto result = processQuery(dq, cs, holders, du->downstream);
+    auto result = processQuery(dq, holders, downstream);
 
     if (result == ProcessQueryResult::Drop) {
       du->status_code = 403;
       handleImmediateResponse(std::move(du), "DoH dropped query");
       return;
     }
-
-    if (result == ProcessQueryResult::SendAnswer) {
+    else if (result == ProcessQueryResult::Asynchronous) {
+      return;
+    }
+    else if (result == ProcessQueryResult::SendAnswer) {
       if (du->response.empty()) {
         du->response = std::move(du->query);
       }
       if (du->response.size() >= sizeof(dnsheader) && du->contentType.empty()) {
         auto dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
 
-        handleResponseSent(ids.qname, QType(ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH);
+        handleResponseSent(du->ids.qname, QType(du->ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH);
       }
       handleImmediateResponse(std::move(du), "DoH self-answered response");
       return;
@@ -684,7 +742,6 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
       return;
     }
 
-    auto downstream = du->downstream;
     if (downstream == nullptr) {
       du->status_code = 502;
       handleImmediateResponse(std::move(du), "DoH no backend available");
@@ -705,7 +762,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
       du->tcp = true;
 
       /* this moves du->ids, careful! */
-      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du));
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du), false);
       cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
       if (downstream->passCrossProtocolQuery(std::move(cpq))) {
@@ -1302,7 +1359,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
       du->truncated = false;
       du->response.clear();
 
-      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du));
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du), false);
 
       if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
         continue;
@@ -1636,14 +1693,14 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse,
   const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
   if (!dh->tc) {
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
-    static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localcacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
+    static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
 
     DNSResponse dr(du->ids, du->response, du->downstream);
     dnsheader cleartextDH;
     memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
     dr.ids.du = std::move(du);
-    if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localcacheInsertedRespRuleActions, dr, false)) {
+    if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
       if (dr.ids.du) {
         dr.ids.du->status_code = 503;
         sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules");
@@ -1651,6 +1708,10 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse,
       return;
     }
 
+    if (dr.isAsynchronous()) {
+      return;
+    }
+
     du = std::move(dr.ids.du);
     double udiff = du->ids.queryRealTime.udiff();
     vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff);
diff --git a/pdns/dnsdistdist/test-dnsdistasync.cc b/pdns/dnsdistdist/test-dnsdistasync.cc
new file mode 100644 (file)
index 0000000..7cf9df5
--- /dev/null
@@ -0,0 +1,165 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnsdist-async.hh"
+
+BOOST_AUTO_TEST_SUITE(test_dnsdistasync)
+
+class DummyQuerySender : public TCPQuerySender
+{
+public:
+  bool active() const override
+  {
+    return true;
+  }
+
+  void handleResponse(const struct timeval&, TCPResponse&&) override
+  {
+  }
+
+  void handleXFRResponse(const struct timeval&, TCPResponse&&) override
+  {
+  }
+
+  void notifyIOError(InternalQueryState&&, const struct timeval&) override
+  {
+    errorRaised = true;
+  }
+
+  bool errorRaised{false};
+};
+
+struct DummyCrossProtocolQuery : public CrossProtocolQuery
+{
+  DummyCrossProtocolQuery() :
+    CrossProtocolQuery()
+  {
+    d_sender = std::make_shared<DummyQuerySender>();
+  }
+
+  std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
+  {
+    return d_sender;
+  }
+
+  std::shared_ptr<DummyQuerySender> d_sender;
+};
+
+BOOST_AUTO_TEST_CASE(test_Basic)
+{
+  auto holder = std::make_unique<dnsdist::AsynchronousHolder>();
+  BOOST_CHECK(holder->empty());
+
+  {
+    auto query = holder->get(0, 0);
+    BOOST_CHECK(query == nullptr);
+  }
+
+  {
+    uint16_t asyncID = 1;
+    uint16_t queryID = 42;
+    struct timeval ttd;
+    gettimeofday(&ttd, nullptr);
+    // timeout in 100 ms
+    ttd.tv_usec += 100000;
+
+    holder->push(asyncID, queryID, ttd, std::make_unique<DummyCrossProtocolQuery>());
+    BOOST_CHECK(!holder->empty());
+
+    auto query = holder->get(0, 0);
+    BOOST_CHECK(query == nullptr);
+
+    query = holder->get(asyncID, queryID);
+    BOOST_CHECK(holder->empty());
+
+    query = holder->get(asyncID, queryID);
+    BOOST_CHECK(query == nullptr);
+
+    // sleep for 200 ms, to be sure the main thread has
+    // been awakened
+    usleep(200000);
+  }
+
+  holder->stop();
+}
+
+BOOST_AUTO_TEST_CASE(test_TimeoutFailClose)
+{
+  auto holder = std::make_unique<dnsdist::AsynchronousHolder>(false);
+  uint16_t asyncID = 1;
+  uint16_t queryID = 42;
+  struct timeval ttd;
+  gettimeofday(&ttd, nullptr);
+  // timeout in 10 ms
+  ttd.tv_usec += 10000;
+
+  std::shared_ptr<DummyQuerySender> sender{nullptr};
+  {
+    auto query = std::make_unique<DummyCrossProtocolQuery>();
+    sender = query->d_sender;
+    BOOST_REQUIRE(sender != nullptr);
+    holder->push(asyncID, queryID, ttd, std::move(query));
+    BOOST_CHECK(!holder->empty());
+  }
+
+  // sleep for 20 ms, to be sure
+  usleep(20000);
+
+  BOOST_CHECK(holder->empty());
+  BOOST_CHECK(sender->errorRaised);
+
+  holder->stop();
+}
+
+BOOST_AUTO_TEST_CASE(test_AddingExpiredEvent)
+{
+  auto holder = std::make_unique<dnsdist::AsynchronousHolder>(false);
+  uint16_t asyncID = 1;
+  uint16_t queryID = 42;
+  struct timeval ttd;
+  gettimeofday(&ttd, nullptr);
+  // timeout was 10 ms ago, for some reason (long processing time, CPU starvation...)
+  ttd.tv_usec -= 10000;
+
+  std::shared_ptr<DummyQuerySender> sender{nullptr};
+  {
+    auto query = std::make_unique<DummyCrossProtocolQuery>();
+    sender = query->d_sender;
+    BOOST_REQUIRE(sender != nullptr);
+    holder->push(asyncID, queryID, ttd, std::move(query));
+    BOOST_CHECK(!holder->empty());
+  }
+
+  // sleep for 20 ms
+  usleep(20000);
+
+  BOOST_CHECK(holder->empty());
+  BOOST_CHECK(sender->errorRaised);
+
+  holder->stop();
+}
+
+BOOST_AUTO_TEST_SUITE_END();
index f9abd334f89e5662ca836f7b026f8a549a1aaf38..41d9992cda3bf8330a3546c7c15a876cf7393c9c 100644 (file)
@@ -599,11 +599,6 @@ public:
     return true;
   }
 
-  const ClientState* getClientState() const override
-  {
-    return nullptr;
-  }
-
   void handleResponse(const struct timeval& now, TCPResponse&& response) override
   {
     if (d_customHandler) {
index e427bf771740697fce9dd988f18b09de9f42e5c5..7da8b3153bdc86b96643eb76cf81d00474d4ba59 100644 (file)
@@ -63,12 +63,12 @@ void handleResponseSent(const InternalQueryState& ids, double udiff, const Combo
 {
 }
 
-static std::function<ProcessQueryResult(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
+static std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
 
-ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
+ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
 {
   if (s_processQuery) {
-    return s_processQuery(dq, cs, holders, selectedBackend);
+    return s_processQuery(dq, selectedBackend);
   }
 
   return ProcessQueryResult::Drop;
@@ -496,7 +496,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, query.size() - 2 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::Drop;
     };
 
@@ -518,7 +518,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       // Would be nicer to actually turn it into a response
       return ProcessQueryResult::SendAnswer;
     };
@@ -550,7 +550,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       // Would be nicer to actually turn it into a response
       return ProcessQueryResult::SendAnswer;
     };
@@ -578,7 +578,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, query.size() - 2 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       throw std::runtime_error("Something unexpected happened");
     };
 
@@ -605,7 +605,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     s_steps.push_back({ ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 });
     s_steps.push_back({ ExpectedStep::ExpectedRequest::closeClient, IOState::Done });
 
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       // Would be nicer to actually turn it into a response
       return ProcessQueryResult::SendAnswer;
     };
@@ -627,7 +627,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, query.size() - 2 - 2 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       /* should not be reached */
       BOOST_CHECK(false);
       return ProcessQueryResult::SendAnswer;
@@ -665,7 +665,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::writeToClient, IOState::NeedWrite, 1 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::SendAnswer;
     };
 
@@ -701,7 +701,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
       { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::SendAnswer;
     };
 
@@ -759,7 +759,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::SendAnswer;
     };
 
@@ -789,7 +789,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, s_proxyProtocolMinimumHeaderSize },
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::SendAnswer;
     };
 
@@ -816,7 +816,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, proxyPayload.size() - s_proxyProtocolMinimumHeaderSize - 1},
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::SendAnswer;
     };
 
@@ -895,7 +895,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       /* closing a connection to the backend */
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -935,7 +935,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       /* closing a connection to the backend */
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -974,7 +974,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       /* closing a connection to the backend */
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1017,7 +1017,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       /* closing a connection to the backend */
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1045,7 +1045,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       /* closing client connection */
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
-    s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       return ProcessQueryResult::SendAnswer;
     };
     s_processResponse = [](PacketBuffer& response, DNSResponse& dr, bool muted) -> bool {
@@ -1082,7 +1082,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       /* closing backend connection */
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1150,7 +1150,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1212,7 +1212,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
 
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
@@ -1248,7 +1248,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
 
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
@@ -1295,7 +1295,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1352,7 +1352,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1408,7 +1408,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1467,7 +1467,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1519,7 +1519,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1579,7 +1579,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1620,7 +1620,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1682,7 +1682,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     /* close the connection with the client */
     s_steps.push_back({ ExpectedStep::ExpectedRequest::closeClient, IOState::Done });
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1702,6 +1702,43 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
 #endif
   }
 
+  {
+    /* 2 queries on the same connection, asynchronously handled, check that we only read the first one (no OOOR as maxInFlight is 0) */
+    TEST_INIT("=> 2 queries on the same connection, async");
+
+    size_t count = 2;
+
+    s_readBuffer = query;
+
+    for (size_t idx = 0; idx < count; idx++) {
+      appendPayloadEditingID(s_readBuffer, query, idx);
+      appendPayloadEditingID(s_backendReadBuffer, query, idx);
+    }
+
+    s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 2 },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, query.size() - 2 },
+                /* close the connection with the client */
+                { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }
+    };
+
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      selectedBackend = backend;
+      dq.asynchronous = true;
+      /* note that we do nothing with the query, we just tell the frontend it was dealt with */
+      return ProcessQueryResult::Asynchronous;
+    };
+    s_processResponse = [](PacketBuffer& response, DNSResponse& dr, bool muted) -> bool {
+      return true;
+    };
+
+    auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    IncomingTCPConnectionState::handleIO(state, now);
+    BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
+
+    /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */
+    IncomingTCPConnectionState::clearAllDownstreamConnections();
+  }
 }
 
 BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
@@ -1871,7 +1908,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -1994,7 +2031,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend,&responses](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend,&responses](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       static size_t count = 0;
       if (count++ == 3) {
         /* self answered */
@@ -2183,7 +2220,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -2255,7 +2292,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     counter = 0;
-    s_processQuery = [backend,&counter](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend,&counter](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       if (counter == 0) {
         ++counter;
         selectedBackend = backend;
@@ -2338,7 +2375,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     counter = 0;
-    s_processQuery = [backend,&counter](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend,&counter](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       if (counter == 0) {
         ++counter;
         selectedBackend = backend;
@@ -2459,7 +2496,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -2611,7 +2648,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -2818,7 +2855,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -2992,7 +3029,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = proxyEnabledBackend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -3256,7 +3293,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -3382,7 +3419,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 },
     };
 
-    s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = proxyEnabledBackend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -3467,7 +3504,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 },
     };
 
-    s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = proxyEnabledBackend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -3532,7 +3569,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -3723,7 +3760,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend1](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend1](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend1;
       return ProcessQueryResult::PassToBackend;
     };
@@ -3808,7 +3845,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -4040,7 +4077,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR)
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done },
     };
 
-    s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
       selectedBackend = backend;
       return ProcessQueryResult::PassToBackend;
     };
@@ -4063,6 +4100,65 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR)
     /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */
     BOOST_CHECK_EQUAL(IncomingTCPConnectionState::clearAllDownstreamConnections(), 5U);
   }
+
+  {
+    /* 2 queries on the same connection, asynchronously handled, check that we only read all of them (OOOR as maxInFlight is 65535) */
+    TEST_INIT("=> 2 queries on the same connection, async with OOOR");
+
+    size_t count = 2;
+
+    s_readBuffer = queries.at(0);
+
+    for (size_t idx = 0; idx < count; idx++) {
+      appendPayloadEditingID(s_readBuffer, queries.at(idx), idx);
+      appendPayloadEditingID(s_backendReadBuffer, queries.at(idx), idx);
+    }
+
+    bool timeout = false;
+    s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 2 },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, queries.at(0).size() - 2 },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 2 },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, queries.at(1).size() - 2 },
+                { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0, [&timeout](int desc) {
+                  timeout = true;
+                }},
+                /* close the connection with the client */
+                { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }
+    };
+
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      selectedBackend = backend;
+      dq.asynchronous = true;
+      /* note that we do nothing with the query, we just tell the frontend it was dealt with */
+      return ProcessQueryResult::Asynchronous;
+    };
+    s_processResponse = [](PacketBuffer& response, DNSResponse& dr, bool muted) -> bool {
+      return true;
+    };
+
+    auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    IncomingTCPConnectionState::handleIO(state, now);
+    while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
+      threadData.mplexer->run(&now);
+    }
+
+    struct timeval later = now;
+    later.tv_sec += g_tcpRecvTimeout + 1;
+    auto expiredConns = threadData.mplexer->getTimeouts(later);
+    BOOST_CHECK_EQUAL(expiredConns.size(), 1U);
+    for (const auto& cbData : expiredConns) {
+      if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
+        auto cbState = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
+        cbState->handleTimeout(cbState, false);
+      }
+    }
+
+    BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
+
+    /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */
+    IncomingTCPConnectionState::clearAllDownstreamConnections();
+  }
 }
 
 BOOST_AUTO_TEST_SUITE_END();
index 62e7f83d2960646687c2e3e5e34672b1cfcfa8df..325776bcdc8800af8e2786fbf66aaac1f82a097e 100644 (file)
@@ -188,6 +188,7 @@ struct DOHUnit
   void release()
   {
   }
+
   size_t proxyProtocolPayloadSize{0};
   uint16_t status_code{200};
 };
@@ -273,6 +274,11 @@ struct DOHUnit
 
 void handleUDPResponseForDoH(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&&, PacketBuffer&& response, InternalQueryState&& state);
 
+struct CrossProtocolQuery;
+struct DNSQuestion;
+
+std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse);
+
 #endif /* HAVE_DNS_OVER_HTTPS  */
 
 using DOHUnitUniquePtr = std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>;
index e8bd82988da84b0c3f1f303f925e18312430ba9c..b37bf28ae4e1379957cf26aa19a9ceb2ffeb1155 100644 (file)
@@ -288,7 +288,7 @@ public:
     return LockGuardedHolder<T>(d_value, d_mutex);
   }
 
-  LockGuardedHolder<const T> read_only_lock() const
+  LockGuardedHolder<const T> read_only_lock()
   {
     return LockGuardedHolder<const T>(d_value, d_mutex);
   }
index dacc24538762d5dbb28563881ae4c7328ba0617a..c4fe42b8aafe46586aa0b81b34c89d167cf88287 100644 (file)
@@ -27,6 +27,8 @@
 
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
+#include "dnsdist-internal-queries.hh"
+#include "dnsdist-tcp.hh"
 #include "dnsdist-xpf.hh"
 
 #include "dolog.hh"
 #include "ednscookies.hh"
 #include "ednssubnet.hh"
 
-bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&)
+ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
+{
+  return ProcessQueryResult::Drop;
+}
+
+bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted)
+{
+  return false;
+}
+
+bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   return false;
 }
@@ -47,6 +59,18 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1
   return false;
 }
 
+namespace dnsdist {
+std::unique_ptr<CrossProtocolQuery> getInternalQueryFromDQ(DNSQuestion& dq, bool isResponse)
+{
+  return nullptr;
+}
+}
+
+bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&)
+{
+  return false;
+}
+
 BOOST_AUTO_TEST_SUITE(test_dnsdist_cc)
 
 static const uint16_t ECSSourcePrefixV4 = 24;