]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix a bound issue, improve readability (thanks, Otto!)
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 22 Jun 2022 09:18:32 +0000 (11:18 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 22 Jun 2022 09:18:32 +0000 (11:18 +0200)
pdns/dnsdistdist/dnsdist-lbpolicies.cc

index a6feacca728bb8d5404ce479ba31a5c2d97568bd..855b9900a4beb18a86070f288f6454559a0bfa38 100644 (file)
 GlobalStateHolder<ServerPolicy> g_policy;
 bool g_roundrobinFailOnNoServer{false};
 
+static constexpr size_t s_staticArrayCutOff = 16;
+template <typename T> using DynamicIndexArray = std::vector<std::pair<T, size_t>>;
+template <typename T> using StaticIndexArray = std::array<std::pair<T, size_t>, s_staticArrayCutOff>;
+
 template <class T> static std::shared_ptr<DownstreamState> getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss)
 {
   /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
@@ -45,7 +49,7 @@ template <class T> static std::shared_ptr<DownstreamState> getLeastOutstanding(c
     return shared_ptr<DownstreamState>();
   }
 
-  nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; });
+  std::nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; });
   // minus 1 because the NumberedServerVector starts at 1 for Lua
   return servers.at(poss.begin()->second - 1).second;
 }
@@ -53,16 +57,18 @@ template <class T> static std::shared_ptr<DownstreamState> getLeastOutstanding(c
 // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
 shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
 {
+  using LeastOustandingType = std::tuple<int,int,double>;
+
   if (servers.size() == 1 && servers[0].second->isUp()) {
     return servers[0].second;
   }
 
-  if (servers.size() <= 16) {
-    std::array<pair<std::tuple<int,int,double>, size_t>, 16> poss;
+  if (servers.size() <= s_staticArrayCutOff) {
+    StaticIndexArray<LeastOustandingType> poss;
     return getLeastOutstanding(servers, poss);
   }
 
-  vector<pair<std::tuple<int,int,double>, size_t>> poss;
+  DynamicIndexArray<LeastOustandingType> poss;
   poss.resize(servers.size());
   return getLeastOutstanding(servers, poss);
 }
@@ -81,8 +87,8 @@ double g_weightedBalancingFactor = 0;
 
 template <class T> static std::shared_ptr<DownstreamState> getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad)
 {
+  constexpr int max = std::numeric_limits<int>::max();
   int sum = 0;
-  int max = std::numeric_limits<int>::max();
 
   size_t usableServers = 0;
   for (const auto& d : servers) {      // w=1, w=10 -> 1, 11
@@ -99,14 +105,14 @@ template <class T> static std::shared_ptr<DownstreamState> getValRandom(const Se
     }
   }
 
-  // Catch poss & sum are empty to avoid SIGFPE
+  // Catch the case where usableServers or sum are equal to 0 to avoid a SIGFPE
   if (usableServers == 0 || sum == 0) {
     return shared_ptr<DownstreamState>();
   }
 
   int r = val % sum;
-  auto p = upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return  r_ < a.first;});
-  if (p == poss.end()) {
+  auto p = std::upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return  r_ < a.first;});
+  if (p == poss.begin() + usableServers) {
     return shared_ptr<DownstreamState>();
   }
 
@@ -116,6 +122,7 @@ template <class T> static std::shared_ptr<DownstreamState> getValRandom(const Se
 
 static shared_ptr<DownstreamState> valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers)
 {
+  using ValRandomType = int;
   double targetLoad = std::numeric_limits<double>::max();
 
   if (g_weightedBalancingFactor > 0) {
@@ -134,12 +141,12 @@ static shared_ptr<DownstreamState> valrandom(const unsigned int val, const Serve
     }
   }
 
-  if (servers.size() <= 16) {
-    std::array<pair<int, size_t>, 16> poss;
+  if (servers.size() <= s_staticArrayCutOff) {
+    StaticIndexArray<ValRandomType> poss;
     return getValRandom(servers, poss, val, targetLoad);
   }
 
-  vector<pair<int, size_t>> poss;
+  DynamicIndexArray<ValRandomType> poss;
   poss.resize(servers.size());
   return getValRandom(servers, poss, val, targetLoad);
 }
@@ -269,7 +276,7 @@ std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string&
     if (!poolName.empty())
       vinfolog("Creating pool %s", poolName);
     pool = std::make_shared<ServerPool>();
-    pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
+    pools.insert(std::pair<std::string, std::shared_ptr<ServerPool> >(poolName, pool));
   }
   return pool;
 }