From: Remi Gacogne Date: Wed, 22 Jun 2022 09:18:32 +0000 (+0200) Subject: dnsdist: Fix a bound issue, improve readability (thanks, Otto!) X-Git-Tag: auth-4.8.0-alpha0~45^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cf6d420781d1a8e3ef1ae633b7fdc96f779cefe1;p=thirdparty%2Fpdns.git dnsdist: Fix a bound issue, improve readability (thanks, Otto!) --- diff --git a/pdns/dnsdistdist/dnsdist-lbpolicies.cc b/pdns/dnsdistdist/dnsdist-lbpolicies.cc index a6feacca72..855b9900a4 100644 --- a/pdns/dnsdistdist/dnsdist-lbpolicies.cc +++ b/pdns/dnsdistdist/dnsdist-lbpolicies.cc @@ -29,6 +29,10 @@ GlobalStateHolder g_policy; bool g_roundrobinFailOnNoServer{false}; +static constexpr size_t s_staticArrayCutOff = 16; +template using DynamicIndexArray = std::vector>; +template using StaticIndexArray = std::array, s_staticArrayCutOff>; + template static std::shared_ptr getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss) { /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort, @@ -45,7 +49,7 @@ template static std::shared_ptr getLeastOutstanding(c return shared_ptr(); } - nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; }); + std::nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; }); // minus 1 because the NumberedServerVector starts at 1 for Lua return servers.at(poss.begin()->second - 1).second; } @@ -53,16 +57,18 @@ template static std::shared_ptr getLeastOutstanding(c // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest shared_ptr leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { + using LeastOustandingType = std::tuple; + if (servers.size() == 1 && servers[0].second->isUp()) { return servers[0].second; } - if (servers.size() <= 16) { - std::array, size_t>, 16> poss; + if (servers.size() <= s_staticArrayCutOff) { + StaticIndexArray poss; return getLeastOutstanding(servers, poss); } - vector, size_t>> poss; + DynamicIndexArray poss; poss.resize(servers.size()); return getLeastOutstanding(servers, poss); } @@ -81,8 +87,8 @@ double g_weightedBalancingFactor = 0; template static std::shared_ptr getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad) { + constexpr int max = std::numeric_limits::max(); int sum = 0; - int max = std::numeric_limits::max(); size_t usableServers = 0; for (const auto& d : servers) { // w=1, w=10 -> 1, 11 @@ -99,14 +105,14 @@ template static std::shared_ptr getValRandom(const Se } } - // Catch poss & sum are empty to avoid SIGFPE + // Catch the case where usableServers or sum are equal to 0 to avoid a SIGFPE if (usableServers == 0 || sum == 0) { return shared_ptr(); } int r = val % sum; - auto p = upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return r_ < a.first;}); - if (p == poss.end()) { + auto p = std::upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return r_ < a.first;}); + if (p == poss.begin() + usableServers) { return shared_ptr(); } @@ -116,6 +122,7 @@ template static std::shared_ptr getValRandom(const Se static shared_ptr valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers) { + using ValRandomType = int; double targetLoad = std::numeric_limits::max(); if (g_weightedBalancingFactor > 0) { @@ -134,12 +141,12 @@ static shared_ptr valrandom(const unsigned int val, const Serve } } - if (servers.size() <= 16) { - std::array, 16> poss; + if (servers.size() <= s_staticArrayCutOff) { + StaticIndexArray poss; return getValRandom(servers, poss, val, targetLoad); } - vector> poss; + DynamicIndexArray poss; poss.resize(servers.size()); return getValRandom(servers, poss, val, targetLoad); } @@ -269,7 +276,7 @@ std::shared_ptr createPoolIfNotExists(pools_t& pools, const string& if (!poolName.empty()) vinfolog("Creating pool %s", poolName); pool = std::make_shared(); - pools.insert(std::pair >(poolName, pool)); + pools.insert(std::pair >(poolName, pool)); } return pool; }