From: Remi Gacogne Date: Mon, 21 Jul 2025 15:04:40 +0000 (+0200) Subject: dnsdist: Refactor load-balancing policies X-Git-Tag: rec-5.4.0-alpha1~224^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2f654e35a979e018975f59d3e5a2db00fc9a150f;p=thirdparty%2Fpdns.git dnsdist: Refactor load-balancing policies Since we no longer need to increase the reference counter of the returned backend (the runtime configuration cannot be updated be updated under our feet anymore), we can return the position of the selected backend in the initial array instead, significantly reducing the performance cost of the load-balancing policies. Signed-off-by: Remi Gacogne --- diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 6857c6fd7f..a493f453a2 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -244,6 +244,7 @@ dnsdist_SOURCES = \ dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-secpoll.cc dnsdist-secpoll.hh \ dnsdist-self-answers.cc dnsdist-self-answers.hh \ + dnsdist-server-pool.hh \ dnsdist-session-cache.cc dnsdist-session-cache.hh \ dnsdist-snmp.cc dnsdist-snmp.hh \ dnsdist-svc.cc dnsdist-svc.hh \ diff --git a/pdns/dnsdistdist/dnsdist-lbpolicies.cc b/pdns/dnsdistdist/dnsdist-lbpolicies.cc index 58075c41a5..f1de02ea18 100644 --- a/pdns/dnsdistdist/dnsdist-lbpolicies.cc +++ b/pdns/dnsdistdist/dnsdist-lbpolicies.cc @@ -31,7 +31,7 @@ static constexpr size_t s_staticArrayCutOff = 16; template using DynamicIndexArray = std::vector>; template using StaticIndexArray = std::array, s_staticArrayCutOff>; -template static std::shared_ptr getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss) +template static std::optional getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss) { /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort, which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */ @@ -44,22 +44,21 @@ template static std::shared_ptr getLeastOutstanding(c } if (usableServers == 0) { - return shared_ptr(); + return std::nullopt; } std::nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; }); - // minus 1 because the NumberedServerVector starts at 1 for Lua - return servers.at(poss.begin()->second - 1).second; + return poss.begin()->second; } // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest -shared_ptr leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +std::optional leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { (void)dq; using LeastOutstandingType = std::tuple; if (servers.size() == 1 && servers[0].second->isUp()) { - return servers[0].second; + return 1; } if (servers.size() <= s_staticArrayCutOff) { @@ -72,17 +71,17 @@ shared_ptr leastOutstanding(const ServerPolicy::NumberedServerV return getLeastOutstanding(servers, poss); } -shared_ptr firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +std::optional firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { for (auto& d : servers) { if (d.second->isUp() && (!d.second->d_qpsLimiter || d.second->d_qpsLimiter->checkOnly())) { - return d.second; + return d.first; } } return leastOutstanding(servers, dq); } -template static std::shared_ptr getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad) +template static std::optional getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad) { constexpr int max = std::numeric_limits::max(); int sum = 0; @@ -105,20 +104,19 @@ template static std::shared_ptr getValRandom(const Se // Catch the case where usableServers or sum are equal to 0 to avoid a SIGFPE if (usableServers == 0 || sum == 0) { - return shared_ptr(); + return std::nullopt; } int r = val % sum; auto p = std::upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return r_ < a.first;}); if (p == poss.begin() + usableServers) { - return shared_ptr(); + return std::nullopt; } - // minus 1 because the NumberedServerVector starts at 1 for Lua - return servers.at(p->second - 1).second; + return p->second; } -static shared_ptr valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers) +static std::optional valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers) { using ValRandomType = int; double targetLoad = std::numeric_limits::max(); @@ -149,28 +147,28 @@ static shared_ptr valrandom(const unsigned int val, const Serve return getValRandom(servers, poss, val, targetLoad); } -shared_ptr wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +std::optional wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { (void)dq; return valrandom(dns_random_uint32(), servers); } -shared_ptr whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash) +std::optional whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash) { return valrandom(hash, servers); } -shared_ptr whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +std::optional whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { const auto hashPerturbation = dnsdist::configuration::getImmutableConfiguration().d_hashPerturbation; return whashedFromHash(servers, dq->ids.qname.hash(hashPerturbation)); } -shared_ptr chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash) +std::optional chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash) { unsigned int sel = std::numeric_limits::max(); unsigned int min = std::numeric_limits::max(); - shared_ptr ret = nullptr, first = nullptr; + std::optional ret, first; double targetLoad = std::numeric_limits::max(); const auto consistentHashBalancingFactor = dnsdist::configuration::getImmutableConfiguration().d_consistentHashBalancingFactor; @@ -197,44 +195,45 @@ shared_ptr chashedFromHash(const ServerPolicy::NumberedServerVe d.second->hash(); } { + const auto position = d.first; const auto& server = d.second; auto hashes = server->hashes.read_lock(); // we want to keep track of the last hash if (min > *(hashes->begin())) { min = *(hashes->begin()); - first = server; + first = position; } auto hash_it = std::lower_bound(hashes->begin(), hashes->end(), qhash); if (hash_it != hashes->end()) { if (*hash_it < sel) { sel = *hash_it; - ret = server; + ret = position; } } } } } - if (ret != nullptr) { + if (ret) { return ret; } - if (first != nullptr) { + if (first) { return first; } - return shared_ptr(); + return std::nullopt; } -shared_ptr chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +std::optional chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { const auto hashPerturbation = dnsdist::configuration::getImmutableConfiguration().d_hashPerturbation; return chashedFromHash(servers, dq->ids.qname.hash(hashPerturbation)); } -shared_ptr roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +std::optional roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { (void)dq; if (servers.empty()) { - return shared_ptr(); + return std::nullopt; } vector candidates; @@ -248,7 +247,7 @@ shared_ptr roundrobin(const ServerPolicy::NumberedServerVector& if (candidates.empty()) { if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_roundrobinFailOnNoServer) { - return shared_ptr(); + return std::nullopt; } for (auto& d : servers) { candidates.push_back(d.first); @@ -256,17 +255,19 @@ shared_ptr roundrobin(const ServerPolicy::NumberedServerVector& } static unsigned int counter; - return servers.at(candidates.at((counter++) % candidates.size()) - 1).second; + return candidates.at((counter++) % candidates.size()); } -shared_ptr orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq) +std::optional orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq) { if (servers.empty()) { - return {}; + return std::nullopt; } ServerPolicy::NumberedServerVector candidates; candidates.reserve(servers.size()); + std::vector positionsMap; + positionsMap.reserve(servers.size()); int curOrder = std::numeric_limits::max(); unsigned int curNumber = 1; @@ -279,14 +280,19 @@ shared_ptr orderedWrandUntag(const ServerPolicy::NumberedServer } curOrder = svr.second->d_config.order; candidates.push_back(ServerPolicy::NumberedServer(curNumber++, svr.second)); + positionsMap.push_back(svr.first); } } if (candidates.empty()) { - return {}; + return std::nullopt; } - return wrandom(candidates, dnsq); + auto selected = wrandom(candidates, dnsq); + if (selected) { + return positionsMap.at(*selected - 1); + } + return selected; } const ServerPolicy::NumberedServerVector& getDownstreamCandidates(const std::string& poolName) @@ -406,42 +412,51 @@ const ServerPolicy::ffipolicyfunc_t& ServerPolicy::getPerThreadPolicy() const return state->d_policies.at(d_name); } -std::shared_ptr ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const +ServerPolicy::SelectedBackend ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const { - std::shared_ptr selectedBackend{nullptr}; + ServerPolicy::SelectedBackend result{servers}; if (d_isLua) { if (!d_isFFI) { - auto lock = g_lua.lock(); - selectedBackend = d_policy(servers, &dq); - } - else { - dnsdist_ffi_dnsquestion_t dnsq(&dq); - dnsdist_ffi_servers_list_t serversList(servers); - unsigned int selected = 0; - - if (!d_isPerThread) { + std::optional position; + { auto lock = g_lua.lock(); - selected = d_ffipolicy(&serversList, &dnsq); + position = d_policy(servers, &dq); } - else { - const auto& policy = getPerThreadPolicy(); - selected = policy(&serversList, &dnsq); + if (position && *position > 0 && *position <= servers.size()) { + result.setSelected(*position - 1); } + return result; + } - if (selected >= servers.size()) { - /* invalid offset, meaning that there is no server available */ - return {}; - } + dnsdist_ffi_dnsquestion_t dnsq(&dq); + dnsdist_ffi_servers_list_t serversList(servers); + ServerPolicy::SelectedServerPosition selected = 0; - selectedBackend = servers.at(selected).second; + if (!d_isPerThread) { + auto lock = g_lua.lock(); + selected = d_ffipolicy(&serversList, &dnsq); + } + else { + const auto& policy = getPerThreadPolicy(); + selected = policy(&serversList, &dnsq); + } + + if (selected >= servers.size()) { + /* invalid offset, meaning that there is no server available */ + return result; } + + result.setSelected(selected); + return result; } - else { - selectedBackend = d_policy(servers, &dq); + + auto position = d_policy(servers, &dq); + if (position && *position > 0 && *position <= servers.size()) { + result.setSelected(*position - 1); } - return selectedBackend; + return result; } namespace dnsdist::lbpolicies diff --git a/pdns/dnsdistdist/dnsdist-lbpolicies.hh b/pdns/dnsdistdist/dnsdist-lbpolicies.hh index a6ea89f706..554b4adaa2 100644 --- a/pdns/dnsdistdist/dnsdist-lbpolicies.hh +++ b/pdns/dnsdistdist/dnsdist-lbpolicies.hh @@ -21,6 +21,9 @@ */ #pragma once +#include +#include + struct dnsdist_ffi_servers_list_t; struct dnsdist_ffi_server_t; struct dnsdist_ffi_dnsquestion_t; @@ -33,14 +36,15 @@ struct PerThreadPoliciesState; class ServerPolicy { public: + using SelectedServerPosition = unsigned int; template using Numbered = std::pair; - using NumberedServer = Numbered>; + using NumberedServer = Numbered>; template using NumberedVector = std::vector>; - using NumberedServerVector = NumberedVector>; - using policyfunc_t = std::function(const NumberedServerVector& servers, const DNSQuestion*)>; - using ffipolicyfunc_t = std::function; + using NumberedServerVector = NumberedVector>; + using policyfunc_t = std::function(const NumberedServerVector& servers, const DNSQuestion*)>; + using ffipolicyfunc_t = std::function; ServerPolicy(const std::string& name_, policyfunc_t policy_, bool isLua_) : d_name(name_), d_policy(std::move(policy_)), d_isLua(isLua_) @@ -59,7 +63,43 @@ public: { } - std::shared_ptr getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const; + class SelectedBackend + { + public: + SelectedBackend(const NumberedServerVector& backends) : + d_backends(&backends) + { + } + + void setSelected(SelectedServerPosition selected) + { + if (selected >= d_backends->size()) { + throw std::runtime_error("Setting an invalid backend position (" + std::to_string(selected) + " out of " + std::to_string(d_backends->size()) + ") from the server policy"); + } + d_selected = selected; + } + + operator bool() const noexcept + { + return d_selected.has_value(); + } + + DownstreamState* operator->() const noexcept + { + return (*d_backends)[*d_selected].second.get(); + } + + const std::shared_ptr& get() const noexcept + { + return (*d_backends)[*d_selected].second; + } + + private: + const NumberedServerVector* d_backends{nullptr}; + std::optional d_selected; + }; + + SelectedBackend getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const; const std::string& getName() const { @@ -68,7 +108,7 @@ public: std::string toString() const { - return string("ServerPolicy") + (d_isLua ? " (Lua)" : "") + " \"" + d_name + "\""; + return std::string("ServerPolicy") + (d_isLua ? " (Lua)" : "") + " \"" + d_name + "\""; } private: @@ -93,23 +133,22 @@ struct ServerPool; using pools_t = std::map>; const ServerPool& getPool(const std::string& poolName); -const ServerPool& createPoolIfNotExists(const string& poolName); -void setPoolPolicy(const string& poolName, std::shared_ptr policy); -void addServerToPool(const string& poolName, std::shared_ptr server); -void removeServerFromPool(const string& poolName, std::shared_ptr server); +const ServerPool& createPoolIfNotExists(const std::string& poolName); +void setPoolPolicy(const std::string& poolName, std::shared_ptr policy); +void addServerToPool(const std::string& poolName, std::shared_ptr server); +void removeServerFromPool(const std::string& poolName, std::shared_ptr server); const ServerPolicy::NumberedServerVector& getDownstreamCandidates(const std::string& poolName); -std::shared_ptr firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); - -std::shared_ptr leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); -std::shared_ptr wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); -std::shared_ptr whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); -std::shared_ptr whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash); -std::shared_ptr chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); -std::shared_ptr chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash); -std::shared_ptr roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); -std::shared_ptr orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq); +std::optional firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); +std::optional leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); +std::optional wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); +std::optional whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); +std::optional whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash); +std::optional chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); +std::optional chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash); +std::optional roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq); +std::optional orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq); #include diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index dfcf23001d..6670eddbf0 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -725,28 +725,24 @@ void dnsdist_ffi_servers_list_get_server(const dnsdist_ffi_servers_list_t* list, *out = &list->ffiServers.at(idx); } -static size_t dnsdist_ffi_servers_get_index_from_server(const ServerPolicy::NumberedServerVector& servers, const std::shared_ptr& server) -{ - for (const auto& pair : servers) { - if (pair.second == server) { - return pair.first - 1; - } - } - throw std::runtime_error("Unable to find servers in server list"); -} - size_t dnsdist_ffi_servers_list_chashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash) { (void)dq; - auto server = chashedFromHash(list->servers, hash); - return dnsdist_ffi_servers_get_index_from_server(list->servers, server); + auto serverPosition = chashedFromHash(list->servers, hash); + if (!serverPosition) { + throw std::runtime_error("Unable to find servers in server list"); + } + return *serverPosition; } size_t dnsdist_ffi_servers_list_whashed(const dnsdist_ffi_servers_list_t* list, const dnsdist_ffi_dnsquestion_t* dq, size_t hash) { (void)dq; - auto server = whashedFromHash(list->servers, hash); - return dnsdist_ffi_servers_get_index_from_server(list->servers, server); + auto serverPosition = whashedFromHash(list->servers, hash); + if (!serverPosition) { + throw std::runtime_error("Unable to find servers in server list"); + } + return *serverPosition; } uint64_t dnsdist_ffi_server_get_outstanding(const dnsdist_ffi_server_t* server) diff --git a/pdns/dnsdistdist/dnsdist.cc b/pdns/dnsdistdist/dnsdist.cc index 3f9f2754b0..a1e58df1c7 100644 --- a/pdns/dnsdistdist/dnsdist.cc +++ b/pdns/dnsdistdist/dnsdist.cc @@ -1424,14 +1424,14 @@ static ProcessQueryResult handleQueryTurnedIntoSelfAnsweredResponse(DNSQuestion& return ProcessQueryResult::SendAnswer; } -static void selectBackendForOutgoingQuery(DNSQuestion& dnsQuestion, const ServerPool& serverPool, std::shared_ptr& selectedBackend) +static ServerPolicy::SelectedBackend selectBackendForOutgoingQuery(DNSQuestion& dnsQuestion, const ServerPool& serverPool) { const auto& policy = serverPool.policy != nullptr ? *serverPool.policy : *dnsdist::configuration::getCurrentRuntimeConfiguration().d_lbPolicy; const auto& servers = serverPool.getServers(); - selectedBackend = policy.getSelectedBackend(servers, dnsQuestion); + return policy.getSelectedBackend(servers, dnsQuestion); } -ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ptr& selectedBackend) +ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ptr& outgoingBackend) { const uint16_t queryId = ntohs(dnsQuestion.getHeader()->id); @@ -1440,7 +1440,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ return handleQueryTurnedIntoSelfAnsweredResponse(dnsQuestion); } const auto& serverPool = getPool(dnsQuestion.ids.poolName); - selectBackendForOutgoingQuery(dnsQuestion, serverPool, selectedBackend); + auto selectedBackend = selectBackendForOutgoingQuery(dnsQuestion, serverPool); bool willBeForwardedOverUDP = !dnsQuestion.overTCP() || dnsQuestion.ids.protocol == dnsdist::Protocol::DoH; if (selectedBackend && selectedBackend->isTCPOnly()) { willBeForwardedOverUDP = false; @@ -1541,7 +1541,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ if (dnsQuestion.ids.poolName != existingPool) { const auto& newServerPool = getPool(dnsQuestion.ids.poolName); dnsQuestion.ids.packetCache = newServerPool.packetCache; - selectBackendForOutgoingQuery(dnsQuestion, newServerPool, selectedBackend); + selectedBackend = selectBackendForOutgoingQuery(dnsQuestion, newServerPool); } else { dnsQuestion.ids.packetCache = serverPool.packetCache; @@ -1581,6 +1581,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_ } selectedBackend->incQueriesCount(); + outgoingBackend = selectedBackend.get(); return ProcessQueryResult::PassToBackend; } catch (const std::exception& e) { diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index 3d3a712b5d..3950fe053c 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -553,7 +553,7 @@ BOOST_AUTO_TEST_CASE(test_lua) local counter = 0 function luaroundrobin(servers, dq) counter = counter + 1 - return servers[1 + (counter % #servers)] + return 1 + (counter % #servers) end setServerPolicyLua("luaroundrobin", luaroundrobin) @@ -582,7 +582,9 @@ BOOST_AUTO_TEST_CASE(test_lua) for (const auto& name : names) { auto dnsQuestion = getDQ(&name); - auto server = pol->getSelectedBackend(servers, dnsQuestion); + auto selectedServer = pol->getSelectedBackend(servers, dnsQuestion); + BOOST_REQUIRE(selectedServer); + const auto& server = selectedServer.get(); BOOST_REQUIRE(serversMap.count(server) == 1); ++serversMap[server]; } diff --git a/regression-tests.dnsdist/test_Routing.py b/regression-tests.dnsdist/test_Routing.py index 8da44193f3..8e143e3b10 100644 --- a/regression-tests.dnsdist/test_Routing.py +++ b/regression-tests.dnsdist/test_Routing.py @@ -395,7 +395,7 @@ class TestRoutingCustomLuaRoundRobinLB(RoundRobinTest, DNSDistTest): local counter = 0 function luaroundrobin(servers_list, dq) counter = counter + 1 - return servers_list[(counter %% #servers_list)+1] + return (counter %% #servers_list)+1 end setServerPolicy(newServerPolicy("custom lua round robin policy", luaroundrobin))