]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactor load-balancing policies
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 21 Jul 2025 15:04:40 +0000 (17:04 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 6 Oct 2025 14:50:24 +0000 (16:50 +0200)
Since we no longer need to increase the reference counter of the
returned backend (the runtime configuration cannot be updated be
updated under our feet anymore), we can return the position of the
selected backend in the initial array instead, significantly
reducing the performance cost of the load-balancing policies.

Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-lbpolicies.cc
pdns/dnsdistdist/dnsdist-lbpolicies.hh
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist.cc
pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
regression-tests.dnsdist/test_Routing.py

index 6857c6fd7f9b78e4057c73a5f0fb5bd70aa7649f..a493f453a22bc892087a3da9539f6d8b7c2b4822 100644 (file)
@@ -244,6 +244,7 @@ dnsdist_SOURCES = \
        dnsdist-rules.cc dnsdist-rules.hh \
        dnsdist-secpoll.cc dnsdist-secpoll.hh \
        dnsdist-self-answers.cc dnsdist-self-answers.hh \
+       dnsdist-server-pool.hh \
        dnsdist-session-cache.cc dnsdist-session-cache.hh \
        dnsdist-snmp.cc dnsdist-snmp.hh \
        dnsdist-svc.cc dnsdist-svc.hh \
index 58075c41a50ef4388a47d6bb108fa4a24cf3da8a..f1de02ea1877bec345bb532b431e85f005c5bd06 100644 (file)
@@ -31,7 +31,7 @@ static constexpr size_t s_staticArrayCutOff = 16;
 template <typename T> using DynamicIndexArray = std::vector<std::pair<T, size_t>>;
 template <typename T> using StaticIndexArray = std::array<std::pair<T, size_t>, s_staticArrayCutOff>;
 
-template <class T> static std::shared_ptr<DownstreamState> getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss)
+template <class T> static std::optional<ServerPolicy::SelectedServerPosition> getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss)
 {
   /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
      which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
@@ -44,22 +44,21 @@ template <class T> static std::shared_ptr<DownstreamState> getLeastOutstanding(c
   }
 
   if (usableServers == 0) {
-    return shared_ptr<DownstreamState>();
+    return std::nullopt;
   }
 
   std::nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; });
-  // minus 1 because the NumberedServerVector starts at 1 for Lua
-  return servers.at(poss.begin()->second - 1).second;
+  return poss.begin()->second;
 }
 
 // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
-shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
   (void)dq;
   using LeastOutstandingType = std::tuple<int,int,double>;
 
   if (servers.size() == 1 && servers[0].second->isUp()) {
-    return servers[0].second;
+    return 1;
   }
 
   if (servers.size() <= s_staticArrayCutOff) {
@@ -72,17 +71,17 @@ shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerV
   return getLeastOutstanding(servers, poss);
 }
 
-shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
   for (auto& d : servers) {
     if (d.second->isUp() && (!d.second->d_qpsLimiter || d.second->d_qpsLimiter->checkOnly())) {
-      return d.second;
+      return d.first;
     }
   }
   return leastOutstanding(servers, dq);
 }
 
-template <class T> static std::shared_ptr<DownstreamState> getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad)
+template <class T> static std::optional<ServerPolicy::SelectedServerPosition> getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad)
 {
   constexpr int max = std::numeric_limits<int>::max();
   int sum = 0;
@@ -105,20 +104,19 @@ template <class T> static std::shared_ptr<DownstreamState> getValRandom(const Se
 
   // Catch the case where usableServers or sum are equal to 0 to avoid a SIGFPE
   if (usableServers == 0 || sum == 0) {
-    return shared_ptr<DownstreamState>();
+    return std::nullopt;
   }
 
   int r = val % sum;
   auto p = std::upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return  r_ < a.first;});
   if (p == poss.begin() + usableServers) {
-    return shared_ptr<DownstreamState>();
+    return std::nullopt;
   }
 
-  // minus 1 because the NumberedServerVector starts at 1 for Lua
-  return servers.at(p->second - 1).second;
+  return p->second;
 }
 
-static shared_ptr<DownstreamState> valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers)
+static std::optional<ServerPolicy::SelectedServerPosition> valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers)
 {
   using ValRandomType = int;
   double targetLoad = std::numeric_limits<double>::max();
@@ -149,28 +147,28 @@ static shared_ptr<DownstreamState> valrandom(const unsigned int val, const Serve
   return getValRandom(servers, poss, val, targetLoad);
 }
 
-shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
   (void)dq;
   return valrandom(dns_random_uint32(), servers);
 }
 
-shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash)
+std::optional<ServerPolicy::SelectedServerPosition> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash)
 {
   return valrandom(hash, servers);
 }
 
-shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
   const auto hashPerturbation = dnsdist::configuration::getImmutableConfiguration().d_hashPerturbation;
   return whashedFromHash(servers, dq->ids.qname.hash(hashPerturbation));
 }
 
-shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash)
+std::optional<ServerPolicy::SelectedServerPosition> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash)
 {
   unsigned int sel = std::numeric_limits<unsigned int>::max();
   unsigned int min = std::numeric_limits<unsigned int>::max();
-  shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
+  std::optional<ServerPolicy::SelectedServerPosition> ret, first;
 
   double targetLoad = std::numeric_limits<double>::max();
   const auto consistentHashBalancingFactor = dnsdist::configuration::getImmutableConfiguration().d_consistentHashBalancingFactor;
@@ -197,44 +195,45 @@ shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVe
         d.second->hash();
       }
       {
+        const auto position = d.first;
         const auto& server = d.second;
         auto hashes = server->hashes.read_lock();
         // we want to keep track of the last hash
         if (min > *(hashes->begin())) {
           min = *(hashes->begin());
-          first = server;
+          first = position;
         }
 
         auto hash_it = std::lower_bound(hashes->begin(), hashes->end(), qhash);
         if (hash_it != hashes->end()) {
           if (*hash_it < sel) {
             sel = *hash_it;
-            ret = server;
+            ret = position;
           }
         }
       }
     }
   }
-  if (ret != nullptr) {
+  if (ret) {
     return ret;
   }
-  if (first != nullptr) {
+  if (first) {
     return first;
   }
-  return shared_ptr<DownstreamState>();
+  return std::nullopt;
 }
 
-shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
   const auto hashPerturbation = dnsdist::configuration::getImmutableConfiguration().d_hashPerturbation;
   return chashedFromHash(servers, dq->ids.qname.hash(hashPerturbation));
 }
 
-shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
   (void)dq;
   if (servers.empty()) {
-    return shared_ptr<DownstreamState>();
+    return std::nullopt;
   }
 
   vector<size_t> candidates;
@@ -248,7 +247,7 @@ shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector&
 
   if (candidates.empty()) {
     if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_roundrobinFailOnNoServer) {
-      return shared_ptr<DownstreamState>();
+      return std::nullopt;
     }
     for (auto& d : servers) {
       candidates.push_back(d.first);
@@ -256,17 +255,19 @@ shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector&
   }
 
   static unsigned int counter;
-  return servers.at(candidates.at((counter++) % candidates.size()) - 1).second;
+  return candidates.at((counter++) % candidates.size());
 }
 
-shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq)
+std::optional<ServerPolicy::SelectedServerPosition> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq)
 {
   if (servers.empty()) {
-    return {};
+    return std::nullopt;
   }
 
   ServerPolicy::NumberedServerVector candidates;
   candidates.reserve(servers.size());
+  std::vector<ServerPolicy::SelectedServerPosition> positionsMap;
+  positionsMap.reserve(servers.size());
 
   int curOrder = std::numeric_limits<int>::max();
   unsigned int curNumber = 1;
@@ -279,14 +280,19 @@ shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServer
       }
       curOrder = svr.second->d_config.order;
       candidates.push_back(ServerPolicy::NumberedServer(curNumber++, svr.second));
+      positionsMap.push_back(svr.first);
     }
   }
 
   if (candidates.empty()) {
-    return {};
+    return std::nullopt;
   }
 
-  return wrandom(candidates, dnsq);
+  auto selected = wrandom(candidates, dnsq);
+  if (selected) {
+    return positionsMap.at(*selected - 1);
+  }
+  return selected;
 }
 
 const ServerPolicy::NumberedServerVector& getDownstreamCandidates(const std::string& poolName)
@@ -406,42 +412,51 @@ const ServerPolicy::ffipolicyfunc_t& ServerPolicy::getPerThreadPolicy() const
   return state->d_policies.at(d_name);
 }
 
-std::shared_ptr<DownstreamState> ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const
+ServerPolicy::SelectedBackend ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const
 {
-  std::shared_ptr<DownstreamState> selectedBackend{nullptr};
+  ServerPolicy::SelectedBackend result{servers};
 
   if (d_isLua) {
     if (!d_isFFI) {
-      auto lock = g_lua.lock();
-      selectedBackend = d_policy(servers, &dq);
-    }
-    else {
-      dnsdist_ffi_dnsquestion_t dnsq(&dq);
-      dnsdist_ffi_servers_list_t serversList(servers);
-      unsigned int selected = 0;
-
-      if (!d_isPerThread) {
+      std::optional<SelectedServerPosition> position;
+      {
         auto lock = g_lua.lock();
-        selected = d_ffipolicy(&serversList, &dnsq);
+        position = d_policy(servers, &dq);
       }
-      else {
-        const auto& policy = getPerThreadPolicy();
-        selected = policy(&serversList, &dnsq);
+      if (position && *position > 0 && *position <= servers.size()) {
+        result.setSelected(*position - 1);
       }
+      return result;
+    }
 
-      if (selected >= servers.size()) {
-        /* invalid offset, meaning that there is no server available */
-        return {};
-      }
+    dnsdist_ffi_dnsquestion_t dnsq(&dq);
+    dnsdist_ffi_servers_list_t serversList(servers);
+    ServerPolicy::SelectedServerPosition selected = 0;
 
-      selectedBackend = servers.at(selected).second;
+    if (!d_isPerThread) {
+      auto lock = g_lua.lock();
+      selected = d_ffipolicy(&serversList, &dnsq);
+    }
+    else {
+      const auto& policy = getPerThreadPolicy();
+      selected = policy(&serversList, &dnsq);
+    }
+
+    if (selected >= servers.size()) {
+      /* invalid offset, meaning that there is no server available */
+      return result;
     }
+
+    result.setSelected(selected);
+    return result;
   }
-  else {
-    selectedBackend = d_policy(servers, &dq);
+
+  auto position = d_policy(servers, &dq);
+  if (position && *position > 0 && *position <= servers.size()) {
+    result.setSelected(*position - 1);
   }
 
-  return selectedBackend;
+  return result;
 }
 
 namespace dnsdist::lbpolicies
index a6ea89f706a396882fe60f63884b53b28836238a..554b4adaa2798b2d62abb5052bda783ae1a0fa62 100644 (file)
@@ -21,6 +21,9 @@
  */
 #pragma once
 
+#include <memory>
+#include <optional>
+
 struct dnsdist_ffi_servers_list_t;
 struct dnsdist_ffi_server_t;
 struct dnsdist_ffi_dnsquestion_t;
@@ -33,14 +36,15 @@ struct PerThreadPoliciesState;
 class ServerPolicy
 {
 public:
+  using SelectedServerPosition = unsigned int;
   template <class T>
   using Numbered = std::pair<unsigned int, T>;
-  using NumberedServer = Numbered<shared_ptr<DownstreamState>>;
+  using NumberedServer = Numbered<std::shared_ptr<DownstreamState>>;
   template <class T>
   using NumberedVector = std::vector<std::pair<unsigned int, T>>;
-  using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
-  using policyfunc_t = std::function<std::shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)>;
-  using ffipolicyfunc_t = std::function<unsigned int(dnsdist_ffi_servers_list_t* servers, dnsdist_ffi_dnsquestion_t* dq)>;
+  using NumberedServerVector = NumberedVector<std::shared_ptr<DownstreamState>>;
+  using policyfunc_t = std::function<std::optional<SelectedServerPosition>(const NumberedServerVector& servers, const DNSQuestion*)>;
+  using ffipolicyfunc_t = std::function<SelectedServerPosition(dnsdist_ffi_servers_list_t* servers, dnsdist_ffi_dnsquestion_t* dq)>;
 
   ServerPolicy(const std::string& name_, policyfunc_t policy_, bool isLua_) :
     d_name(name_), d_policy(std::move(policy_)), d_isLua(isLua_)
@@ -59,7 +63,43 @@ public:
   {
   }
 
-  std::shared_ptr<DownstreamState> getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const;
+  class SelectedBackend
+  {
+  public:
+    SelectedBackend(const NumberedServerVector& backends) :
+      d_backends(&backends)
+    {
+    }
+
+    void setSelected(SelectedServerPosition selected)
+    {
+      if (selected >= d_backends->size()) {
+        throw std::runtime_error("Setting an invalid backend position (" + std::to_string(selected) + " out of " + std::to_string(d_backends->size()) + ") from the server policy");
+      }
+      d_selected = selected;
+    }
+
+    operator bool() const noexcept
+    {
+      return d_selected.has_value();
+    }
+
+    DownstreamState* operator->() const noexcept
+    {
+      return (*d_backends)[*d_selected].second.get();
+    }
+
+    const std::shared_ptr<DownstreamState>& get() const noexcept
+    {
+      return (*d_backends)[*d_selected].second;
+    }
+
+  private:
+    const NumberedServerVector* d_backends{nullptr};
+    std::optional<SelectedServerPosition> d_selected;
+  };
+
+  SelectedBackend getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const;
 
   const std::string& getName() const
   {
@@ -68,7 +108,7 @@ public:
 
   std::string toString() const
   {
-    return string("ServerPolicy") + (d_isLua ? " (Lua)" : "") + " \"" + d_name + "\"";
+    return std::string("ServerPolicy") + (d_isLua ? " (Lua)" : "") + " \"" + d_name + "\"";
   }
 
 private:
@@ -93,23 +133,22 @@ struct ServerPool;
 
 using pools_t = std::map<std::string, std::shared_ptr<ServerPool>>;
 const ServerPool& getPool(const std::string& poolName);
-const ServerPool& createPoolIfNotExists(const string& poolName);
-void setPoolPolicy(const string& poolName, std::shared_ptr<ServerPolicy> policy);
-void addServerToPool(const string& poolName, std::shared_ptr<DownstreamState> server);
-void removeServerFromPool(const string& poolName, std::shared_ptr<DownstreamState> server);
+const ServerPool& createPoolIfNotExists(const std::string& poolName);
+void setPoolPolicy(const std::string& poolName, std::shared_ptr<ServerPolicy> policy);
+void addServerToPool(const std::string& poolName, std::shared_ptr<DownstreamState> server);
+void removeServerFromPool(const std::string& poolName, std::shared_ptr<DownstreamState> server);
 
 const ServerPolicy::NumberedServerVector& getDownstreamCandidates(const std::string& poolName);
 
-std::shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-
-std::shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
-std::shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
-std::shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq);
+std::optional<ServerPolicy::SelectedServerPosition> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::optional<ServerPolicy::SelectedServerPosition> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::optional<ServerPolicy::SelectedServerPosition> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq);
 
 #include <unordered_map>
 
index dfcf23001da35c1dcb0e49a8936b96d0c9ceb337..6670eddbf05ad6719faeee987b2cae4ece30e8f9 100644 (file)
@@ -725,28 +725,24 @@ void dnsdist_ffi_servers_list_get_server(const dnsdist_ffi_servers_list_t* list,
   *out = &list->ffiServers.at(idx);
 }
 
-static size_t dnsdist_ffi_servers_get_index_from_server(const ServerPolicy::NumberedServerVector& servers, const std::shared_ptr<DownstreamState>& server)
-{
-  for (const auto& pair : servers) {
-    if (pair.second == server) {
-      return pair.first - 1;
-    }
-  }
-  throw std::runtime_error("Unable to find servers in server list");
-}
-
 size_t dnsdist_ffi_servers_list_chashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash)
 {
   (void)dq;
-  auto server = chashedFromHash(list->servers, hash);
-  return dnsdist_ffi_servers_get_index_from_server(list->servers, server);
+  auto serverPosition = chashedFromHash(list->servers, hash);
+  if (!serverPosition) {
+    throw std::runtime_error("Unable to find servers in server list");
+  }
+  return *serverPosition;
 }
 
 size_t dnsdist_ffi_servers_list_whashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash)
 {
   (void)dq;
-  auto server = whashedFromHash(list->servers, hash);
-  return dnsdist_ffi_servers_get_index_from_server(list->servers, server);
+  auto serverPosition = whashedFromHash(list->servers, hash);
+  if (!serverPosition) {
+    throw std::runtime_error("Unable to find servers in server list");
+  }
+  return *serverPosition;
 }
 
 uint64_t dnsdist_ffi_server_get_outstanding(const dnsdist_ffi_server_t* server)
index 3f9f2754b07b8568f780b43691ea8abddbf63f54..a1e58df1c761c71b3b052de13dd464c5263736c1 100644 (file)
@@ -1424,14 +1424,14 @@ static ProcessQueryResult handleQueryTurnedIntoSelfAnsweredResponse(DNSQuestion&
   return ProcessQueryResult::SendAnswer;
 }
 
-static void selectBackendForOutgoingQuery(DNSQuestion& dnsQuestion, const ServerPool& serverPool, std::shared_ptr<DownstreamState>& selectedBackend)
+static ServerPolicy::SelectedBackend selectBackendForOutgoingQuery(DNSQuestion& dnsQuestion, const ServerPool& serverPool)
 {
   const auto& policy = serverPool.policy != nullptr ? *serverPool.policy : *dnsdist::configuration::getCurrentRuntimeConfiguration().d_lbPolicy;
   const auto& servers = serverPool.getServers();
-  selectedBackend = policy.getSelectedBackend(servers, dnsQuestion);
+  return policy.getSelectedBackend(servers, dnsQuestion);
 }
 
-ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ptr<DownstreamState>& selectedBackend)
+ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ptr<DownstreamState>& outgoingBackend)
 {
   const uint16_t queryId = ntohs(dnsQuestion.getHeader()->id);
 
@@ -1440,7 +1440,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
       return handleQueryTurnedIntoSelfAnsweredResponse(dnsQuestion);
     }
     const auto& serverPool = getPool(dnsQuestion.ids.poolName);
-    selectBackendForOutgoingQuery(dnsQuestion, serverPool, selectedBackend);
+    auto selectedBackend = selectBackendForOutgoingQuery(dnsQuestion, serverPool);
     bool willBeForwardedOverUDP = !dnsQuestion.overTCP() || dnsQuestion.ids.protocol == dnsdist::Protocol::DoH;
     if (selectedBackend && selectedBackend->isTCPOnly()) {
       willBeForwardedOverUDP = false;
@@ -1541,7 +1541,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
       if (dnsQuestion.ids.poolName != existingPool) {
         const auto& newServerPool = getPool(dnsQuestion.ids.poolName);
         dnsQuestion.ids.packetCache = newServerPool.packetCache;
-        selectBackendForOutgoingQuery(dnsQuestion, newServerPool, selectedBackend);
+        selectedBackend = selectBackendForOutgoingQuery(dnsQuestion, newServerPool);
       }
       else {
         dnsQuestion.ids.packetCache = serverPool.packetCache;
@@ -1581,6 +1581,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
     }
 
     selectedBackend->incQueriesCount();
+    outgoingBackend = selectedBackend.get();
     return ProcessQueryResult::PassToBackend;
   }
   catch (const std::exception& e) {
index 3d3a712b5ddc16d3a1b5e7d12e4f80a1bffc5ab7..3950fe053cd3f7a39dccb91d73751233bf73f3db 100644 (file)
@@ -553,7 +553,7 @@ BOOST_AUTO_TEST_CASE(test_lua)
     local counter = 0
     function luaroundrobin(servers, dq)
       counter = counter + 1
-      return servers[1 + (counter % #servers)]
+      return 1 + (counter % #servers)
     end
 
     setServerPolicyLua("luaroundrobin", luaroundrobin)
@@ -582,7 +582,9 @@ BOOST_AUTO_TEST_CASE(test_lua)
 
     for (const auto& name : names) {
       auto dnsQuestion = getDQ(&name);
-      auto server = pol->getSelectedBackend(servers, dnsQuestion);
+      auto selectedServer = pol->getSelectedBackend(servers, dnsQuestion);
+      BOOST_REQUIRE(selectedServer);
+      const auto& server = selectedServer.get();
       BOOST_REQUIRE(serversMap.count(server) == 1);
       ++serversMap[server];
     }
index 8da44193f3e5652b615286fd4664a9f0da28f199..8e143e3b10fd6a40a25f807163112fe0a24dd84a 100644 (file)
@@ -395,7 +395,7 @@ class TestRoutingCustomLuaRoundRobinLB(RoundRobinTest, DNSDistTest):
     local counter = 0
     function luaroundrobin(servers_list, dq)
       counter = counter + 1
-      return servers_list[(counter %% #servers_list)+1]
+      return (counter %% #servers_list)+1
     end
     setServerPolicy(newServerPolicy("custom lua round robin policy", luaroundrobin))