]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Stronger guarantees against data race in the UDP path
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 28 Sep 2022 15:21:16 +0000 (17:21 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 11 Jan 2023 11:28:20 +0000 (12:28 +0100)
pdns/dnsdist-idstate.hh
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/doh.cc
pdns/test-dnsdist_cc.cc

index 1932301ad05201d2b0b63dc95187a628f6160e5d..1fb8116be16af15016b9c326f5ecfef6e3905415 100644 (file)
@@ -162,73 +162,25 @@ struct IDState
   IDState(const IDState& orig) = delete;
   IDState(IDState&& rhs)
   {
-    if (rhs.isInUse()) {
-      throw std::runtime_error("Trying to move an in-use IDState");
-    }
-
-#ifdef __SANITIZE_THREAD__
+    inUse.store(rhs.inUse.load());
     age.store(rhs.age.load());
-#else
-    age = rhs.age;
-#endif
     internal = std::move(rhs.internal);
   }
 
   IDState& operator=(IDState&& rhs)
   {
-    if (isInUse()) {
-      throw std::runtime_error("Trying to overwrite an in-use IDState");
-    }
-
-    if (rhs.isInUse()) {
-      throw std::runtime_error("Trying to move an in-use IDState");
-    }
-#ifdef __SANITIZE_THREAD__
+    inUse.store(rhs.inUse.load());
     age.store(rhs.age.load());
-#else
-    age = rhs.age;
-#endif
-
     internal = std::move(rhs.internal);
-
     return *this;
   }
 
-  static const int64_t unusedIndicator = -1;
-
-  static bool isInUse(int64_t usageIndicator)
-  {
-    return usageIndicator != unusedIndicator;
-  }
-
   bool isInUse() const
   {
-    return usageIndicator != unusedIndicator;
-  }
-
-  /* return true if the value has been successfully replaced meaning that
-     no-one updated the usage indicator in the meantime */
-  bool tryMarkUnused(int64_t expectedUsageIndicator)
-  {
-    return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator);
-  }
-
-  /* mark as used no matter what, return true if the state was in use before */
-  bool markAsUsed()
-  {
-    auto currentGeneration = generation++;
-    return markAsUsed(currentGeneration);
-  }
-
-  /* mark as used no matter what, return true if the state was in use before */
-  bool markAsUsed(int64_t currentGeneration)
-  {
-    int64_t oldUsage = usageIndicator.exchange(currentGeneration);
-    return oldUsage != unusedIndicator;
+    return inUse == true;
   }
 
-  /* We use this value to detect whether this state is in use.
-     For performance reasons we don't want to use a lock here, but that means
+  /* For performance reasons we don't want to use a lock here, but that means
      we need to be very careful when modifying this value. Modifications happen
      from:
      - one of the UDP or DoH 'client' threads receiving a query, selecting a backend
@@ -246,26 +198,52 @@ struct IDState
        the corresponding state and sending the response to the client ;
      - the 'healthcheck' thread scanning the states to actively discover timeouts,
        mostly to keep some counters like the 'outstanding' one sane.
-     We previously based that logic on the origFD (FD on which the query was received,
-     and therefore from where the response should be sent) but this suffered from an
-     ABA problem since it was quite likely that a UDP 'client thread' would reset it to the
-     same value since we only have so much incoming sockets:
-     - 1/ 'client' thread gets a query and set origFD to its FD, say 5 ;
-     - 2/ 'receiver' thread gets a response, read the value of origFD to 5, check that the qname,
-       qtype and qclass match
-     - 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ;
-     - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still
-       5, except it's not the same 5 anymore and it overrides a fresh state.
-     We now use a 32-bit unsigned counter instead, which is incremented every time the state is set,
-     wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1
-     when the state is unused and the value of our counter otherwise.
+
+     We have two flags:
+     - inUse tells us if there currently is a in-flight query whose state is stored
+       in this state
+     - locked tells us whether someone currently owns the state, so no-one else can touch
+       it
   */
   InternalQueryState internal;
-  std::atomic<int64_t> usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty   // 8
-  std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue    // 4
-#ifdef __SANITIZE_THREAD__
   std::atomic<uint16_t> age{0};
-#else
-  uint16_t age{0}; // 2
-#endif
+
+  class StateGuard
+  {
+  public:
+    StateGuard(IDState& ids) :
+      d_ids(ids)
+    {
+    }
+    ~StateGuard()
+    {
+      d_ids.release();
+    }
+    StateGuard(const StateGuard&) = delete;
+    StateGuard(StateGuard&&) = delete;
+    StateGuard& operator=(const StateGuard&) = delete;
+    StateGuard& operator=(StateGuard&&) = delete;
+
+  private:
+    IDState& d_ids;
+  };
+
+  [[nodiscard]] std::optional<StateGuard> acquire()
+  {
+    bool expected = false;
+    if (locked.compare_exchange_strong(expected, true)) {
+      return std::optional<StateGuard>(*this);
+    }
+    return std::nullopt;
+  }
+
+  void release()
+  {
+    locked.store(false);
+  }
+
+  std::atomic<bool> inUse{false}; // 1
+
+private:
+  std::atomic<bool> locked{false}; // 1
 };
index 6b751c281fecd29a8e86073618aebdb6706d808a..0b6aa516f346fc69726ad6bfbba46221aa148859 100644 (file)
@@ -634,7 +634,7 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff,
   doLatencyStats(incomingProtocol, udiff);
 }
 
-static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated, std::optional<uint16_t> queryId)
+static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated)
 {
   DNSResponse dr(ids, response, ds);
 
@@ -653,9 +653,6 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
   if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
-    if (queryId) {
-      ds->releaseState(*queryId);
-    }
     return;
   }
 
@@ -686,10 +683,6 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   else {
     handleResponseSent(ids, 0., dr.ids.origRemote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP);
   }
-
-  if (queryId) {
-    ds->releaseState(*queryId);
-  }
 }
 
 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
@@ -728,57 +721,23 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
         dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
         queryId = dh->id;
 
-        IDState* ids = dss->getExistingState(queryId);
-        if (ids == nullptr) {
+        auto ids = dss->getState(queryId);
+        if (!ids) {
           continue;
         }
 
-        int64_t usageIndicator = ids->usageIndicator;
-
-        if (!IDState::isInUse(usageIndicator)) {
-          /* the corresponding state is marked as not in use, meaning that:
-             - it was already cleaned up by another thread and the state is gone ;
-             - we already got a response for this query and this one is a duplicate.
-             Either way, we don't touch it.
-          */
-          continue;
-        }
-
-        /* setting age to 0 to prevent the maintainer thread from
-           cleaning this IDS while we process the response.
-        */
-        ids->age = 0;
-
         unsigned int qnameWireLength = 0;
-        if (fd != ids->internal.backendFD || !responseContentMatches(response, ids->internal.qname, ids->internal.qtype, ids->internal.qclass, dss, qnameWireLength)) {
+        if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) {
+          dss->restoreState(queryId, std::move(*ids));
           continue;
         }
 
-        DOHUnitUniquePtr du(nullptr, DOHUnit::release);
-        /* atomically mark the state as available, but only if it has not been altered
-           in the meantime */
-        if (ids->tryMarkUnused(usageIndicator)) {
-          /* clear the potential DOHUnit asap, it's ours now
-           and since we just marked the state as unused,
-           someone could overwrite it. */
-          du = std::move(ids->internal.du);
-          /* we only decrement the outstanding counter if the value was not
-             altered in the meantime, which would mean that the state has been actively reused
-             and the other thread has not incremented the outstanding counter, so we don't
-             want it to be decremented twice. */
-          --dss->outstanding;  // you'd think an attacker could game this, but we're using connected socket
-        } else {
-          /* someone updated the state in the meantime, we can't touch the existing pointer */
-          du.release();
-          /* since the state has been updated, we can't safely access it so let's just drop
-             this response */
-          continue;
-        }
+        auto du = std::move(ids->du);
 
-        dh->id = ids->internal.origID;
+        dh->id = ids->origID;
         ++dss->responses;
 
-        double udiff = ids->internal.queryRealTime.udiff();
+        double udiff = ids->queryRealTime.udiff();
         // do that _before_ the processing, otherwise it's not fair to the backend
         dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
         dss->reportResponse(dh->rcode);
@@ -787,13 +746,12 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
         if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
           // DoH query, we cannot touch du after that
-          handleUDPResponseForDoH(std::move(du), std::move(response), std::move(ids->internal));
+          handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids));
 #endif
-          dss->releaseState(queryId);
           continue;
         }
 
-        handleResponseForUDPClient(ids->internal, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId);
+        handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false);
       }
     }
     catch (const std::exception& e) {
@@ -1445,7 +1403,7 @@ public:
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
 
-    handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated, std::nullopt);
+    handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated);
   }
 
   void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
@@ -1487,62 +1445,55 @@ public:
   }
 };
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
 {
   bool doh = dq.ids.du != nullptr;
-  unsigned int idOffset = 0;
-  int64_t generation;
-  IDState* ids = ds->getIDState(idOffset, generation);
-
-  dq.getHeader()->id = idOffset;
 
   bool failed = false;
+  size_t proxyPayloadSize = 0;
   if (ds->d_config.useProxyProtocol) {
     try {
-      size_t payloadSize = 0;
-      if (addProxyProtocol(dq, &payloadSize)) {
+      if (addProxyProtocol(dq, &proxyPayloadSize)) {
         if (dq.ids.du) {
-          dq.ids.du->proxyProtocolPayloadSize = payloadSize;
+          dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize;
         }
       }
     }
     catch (const std::exception& e) {
       vinfolog("Adding proxy protocol payload to %squery from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what());
-      failed = true;
+      return false;
     }
   }
 
   try {
-    if (!failed) {
-      int fd = ds->pickSocketForSending();
-      dq.ids.backendFD = fd;
-      dq.ids.origID = queryID;
-      dq.ids.forwardedOverUDP = true;
-      ids->internal = std::move(dq.ids);
+    int fd = ds->pickSocketForSending();
+    dq.ids.backendFD = fd;
+    dq.ids.origID = queryID;
+    dq.ids.forwardedOverUDP = true;
 
-      vinfolog("Got query for %s|%s from %s%s, relayed to %s", ids->internal.qname.toLogString(), QType(ids->internal.qtype).toString(), ids->internal.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr());
-      /* you can't touch du after this line, unless the call returned a non-negative value,
-         because it might already have been freed */
-      ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
+    vinfolog("Got query for %s|%s from %s%s, relayed to %s", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr());
 
-      if (ret < 0) {
-        failed = true;
-      }
-    }
-    else {
-      ids->internal = std::move(dq.ids);
+    auto idOffset = ds->saveState(std::move(dq.ids));
+    /* set the correct ID */
+    memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset));
+
+    /* you can't touch ids or du after this line, unless the call returned a non-negative value,
+       because it might already have been freed */
+    ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
+
+    if (ret < 0) {
+      failed = true;
     }
 
     if (failed) {
-      /* we are about to handle the error, make sure that
-         this pointer is not accessed when the state is cleaned,
-         but first check that it still belongs to us */
-      if (ids->tryMarkUnused(generation) && ids->internal.du) {
-        dq.ids.du = std::move(ids->internal.du);
-        --ds->outstanding;
-      }
-      if (dq.ids.du) {
-        dq.ids.du->status_code = 502;
+      /* clear up the state. In the very unlikely event it was reused
+         in the meantime, so be it. */
+      auto cleared = ds->getState(idOffset);
+      if (cleared) {
+        dq.ids.du = std::move(cleared->du);
+        if (dq.ids.du) {
+          dq.ids.du->status_code = 502;
+        }
       }
       ++g_stats.downstreamSendErrors;
       ++ds->sendErrors;
@@ -1667,7 +1618,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       return;
     }
 
-    assignOutgoingUDPQueryToBackend(ss, dh->id, dq, std::move(query), dest);
+    assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest);
   }
   catch(const std::exception& e){
     vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what());
index 7c8a3ec3641d16b308f81ab2f3776b5286d8becd..a425018e6ecd0af4ddf9ed83c69460b5cbf94c86 100644 (file)
@@ -1002,13 +1002,13 @@ public:
   int pickSocketForSending();
   void pickSocketsReadyForReceiving(std::vector<int>& ready);
   void handleUDPTimeouts();
-  IDState* getIDState(unsigned int& id, int64_t& generation);
-  IDState* getExistingState(unsigned int id);
-  void releaseState(unsigned int id);
   void reportTimeoutOrError();
   void reportResponse(uint8_t rcode);
   void submitHealthCheckResult(bool initial, bool newState);
   time_t getNextLazyHealthCheck();
+  uint16_t saveState(InternalQueryState&&);
+  void restoreState(uint16_t id, InternalQueryState&&);
+  std::optional<InternalQueryState> getState(uint16_t id);
 
   dnsdist::Protocol getProtocol() const
   {
@@ -1206,7 +1206,7 @@ static const size_t s_maxPacketCacheEntrySize{4096}; // don't cache responses la
 enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend };
 ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest);
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest);
 
 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
 void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol);
index daff23d6aa313b9cbebb160de829d9e0c9e9c048..eea141e440509a02dd149d373910882b5bb788f7 100644 (file)
@@ -321,19 +321,16 @@ bool DownstreamState::s_randomizeSockets{false};
 bool DownstreamState::s_randomizeIDs{false};
 int DownstreamState::s_udpTimeout{2};
 
-static bool isIDSExpired(IDState& ids)
+static bool isIDSExpired(const IDState& ids)
 {
-  auto age = ids.age++;
+  auto age = ids.age.load();
   return age > DownstreamState::s_udpTimeout;
 }
 
 void DownstreamState::handleUDPTimeout(IDState& ids)
 {
-  /* We mark the state as unused as soon as possible
-     to limit the risk of racing with the
-     responder thread.
-  */
   ids.age = 0;
+  ids.inUse = false;
   handleDOHTimeout(std::move(ids.internal.du));
   reuseds++;
   --outstanding;
@@ -386,20 +383,26 @@ void DownstreamState::handleUDPTimeouts()
         it = map->erase(it);
         continue;
       }
+      ++ids.age;
       ++it;
     }
   }
   else {
     if (outstanding.load() > 0) {
       for (IDState& ids : idStates) {
-        int64_t usageIndicator = ids.usageIndicator;
-        if (IDState::isInUse(usageIndicator) && isIDSExpired(ids)) {
-          if (!ids.tryMarkUnused(usageIndicator)) {
-            /* this state has been altered in the meantime,
-               don't go anywhere near it */
-            continue;
-          }
-
+        if (!ids.isInUse()) {
+          continue;
+        }
+        if (!isIDSExpired(ids)) {
+          ++ids.age;
+          continue;
+        }
+        auto guard = ids.acquire();
+        if (!guard) {
+          continue;
+        }
+        /* check again, now that we have locked this state */
+        if (ids.isInUse() && isIDSExpired(ids)) {
           handleUDPTimeout(ids);
         }
       }
@@ -407,89 +410,143 @@ void DownstreamState::handleUDPTimeouts()
   }
 }
 
-IDState* DownstreamState::getExistingState(unsigned int stateId)
+uint16_t DownstreamState::saveState(InternalQueryState&& state)
 {
   if (s_randomizeIDs) {
+    /* if the state is already in use we will retry,
+       up to 5 five times. The last selected one is used
+       even if it was already in use */
+    size_t remainingAttempts = 5;
     auto map = d_idStatesMap.lock();
-    auto it = map->find(stateId);
-    if (it == map->end()) {
-      return nullptr;
+
+    do {
+      uint16_t selectedID = dnsdist::getRandomValue(std::numeric_limits<uint16_t>::max());
+      auto [it, inserted] = map->emplace(selectedID, IDState());
+
+      if (!inserted) {
+        remainingAttempts--;
+        if (remainingAttempts > 0) {
+          continue;
+        }
+
+        auto oldDU = std::move(it->second.internal.du);
+        ++reuseds;
+        ++g_stats.downstreamTimeouts;
+        handleDOHTimeout(std::move(oldDU));
+      }
+      else {
+        ++outstanding;
+      }
+
+      it->second.internal = std::move(state);
+      it->second.age.store(0);
+
+      return it->first;
     }
-    return &it->second;
+    while (true);
   }
-  else {
-    if (stateId >= idStates.size()) {
-      return nullptr;
+
+  do {
+    IDState* ids = nullptr;
+    uint16_t selectedID = (idOffset++) % idStates.size();
+    ids = &idStates[selectedID];
+    auto guard = ids->acquire();
+    if (!guard) {
+      continue;
+    }
+    if (ids->isInUse()) {
+      /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
+         to handle it because it's about to be overwritten. */
+      auto oldDU = std::move(ids->internal.du);
+      ++reuseds;
+      ++g_stats.downstreamTimeouts;
+      handleDOHTimeout(std::move(oldDU));
     }
-    return &idStates[stateId];
+    else {
+      ++outstanding;
+    }
+    ids->internal = std::move(state);
+    ids->age.store(0);
+    ids->inUse = true;
+    return selectedID;
   }
+  while (true);
 }
 
-void DownstreamState::releaseState(unsigned int stateId)
+void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state)
 {
   if (s_randomizeIDs) {
     auto map = d_idStatesMap.lock();
-    auto it = map->find(stateId);
-    if (it == map->end()) {
-      return;
+
+    auto [it, inserted] = map->emplace(id, IDState());
+    if (!inserted) {
+      /* already used */
+      ++reuseds;
+      ++g_stats.downstreamTimeouts;
+      handleDOHTimeout(std::move(state.du));
     }
-    if (it->second.isInUse()) {
-      return;
+    else {
+      it->second.internal = std::move(state);
+      ++outstanding;
     }
-    map->erase(it);
+    return;
   }
+
+  auto& ids = idStates[id];
+  auto guard = ids.acquire();
+  if (!guard) {
+    /* already used */
+    ++reuseds;
+    ++g_stats.downstreamTimeouts;
+    handleDOHTimeout(std::move(state.du));
+    return;
+  }
+  if (ids.isInUse()) {
+    /* already used */
+    ++reuseds;
+    ++g_stats.downstreamTimeouts;
+    handleDOHTimeout(std::move(state.du));
+    return;
+  }
+  ids.internal = std::move(state);
+  ids.inUse = true;
+  ++outstanding;
 }
 
-IDState* DownstreamState::getIDState(unsigned int& selectedID, int64_t& generation)
+std::optional<InternalQueryState> DownstreamState::getState(uint16_t id)
 {
-  DOHUnitUniquePtr du(nullptr, DOHUnit::release);
-  IDState* ids = nullptr;
+  std::optional<InternalQueryState> result = std::nullopt;
+
   if (s_randomizeIDs) {
-    /* if the state is already in use we will retry,
-       up to 5 five times. The last selected one is used
-       even if it was already in use */
-    size_t remainingAttempts = 5;
     auto map = d_idStatesMap.lock();
 
-    bool done = false;
-    do {
-      selectedID = dnsdist::getRandomValue(std::numeric_limits<uint16_t>::max());
-      auto [it, inserted] = map->insert({selectedID, IDState()});
-      ids = &it->second;
-      if (inserted) {
-        done = true;
-      }
-      else {
-        remainingAttempts--;
-      }
+    auto it = map->find(id);
+    if (it == map->end()) {
+      return result;
     }
-    while (!done && remainingAttempts > 0);
-  }
-  else {
-    selectedID = (idOffset++) % idStates.size();
-    ids = &idStates[selectedID];
-  }
 
-  ids->age = 0;
+    result = std::move(it->second.internal);
+    map->erase(it);
+    --outstanding;
+    return result;
+  }
 
-  /* we atomically replace the value, we now own this state */
-  generation = ids->generation++;
-  if (!ids->markAsUsed(generation)) {
-    /* the state was not in use.
-       we reset 'du' because it might have still been in use when we read it. */
-    du.release();
-    ++outstanding;
+  if (id > idStates.size()) {
+    return result;
   }
-  else {
-    /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
-       to handle it because it's about to be overwritten. */
-    auto oldDU = std::move(ids->internal.du);
-    ++reuseds;
-    ++g_stats.downstreamTimeouts;
-    handleDOHTimeout(std::move(oldDU));
+
+  auto& ids = idStates[id];
+  auto guard = ids.acquire();
+  if (!guard) {
+    return result;
   }
 
-  return ids;
+  if (ids.isInUse()) {
+    result = std::move(ids.internal);
+    --outstanding;
+  }
+  ids.inUse = false;
+  return result;
 }
 
 bool DownstreamState::healthCheckRequired(std::optional<time_t> currentTime)
index d6eb18bbc9fe40f673074b7784e753d351e22446..b66836b3675b949852fd5494da22fac4f93c1214 100644 (file)
@@ -677,7 +677,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit)
     }
 
     ComboAddress dest = dq.ids.origDest;
-    if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, std::move(du->query), dest)) {
+    if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, du->query, dest)) {
             sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
       return;
     }
index 71e140b99c25090f74846ff3469eb9325ce57f28..dacc24538762d5dbb28563881ae4c7328ba0617a 100644 (file)
@@ -42,7 +42,7 @@ bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&)
   return false;
 }
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
 {
   return false;
 }