template <typename T> using DynamicIndexArray = std::vector<std::pair<T, size_t>>;
template <typename T> using StaticIndexArray = std::array<std::pair<T, size_t>, s_staticArrayCutOff>;
-template <class T> static std::shared_ptr<DownstreamState> getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss)
+template <class T> static std::optional<ServerPolicy::SelectedServerPosition> getLeastOutstanding(const ServerPolicy::NumberedServerVector& servers, T& poss)
{
/* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
}
if (usableServers == 0) {
- return shared_ptr<DownstreamState>();
+ return std::nullopt;
}
std::nth_element(poss.begin(), poss.begin(), poss.begin() + usableServers, [](const typename T::value_type& a, const typename T::value_type& b) { return a.first < b.first; });
- // minus 1 because the NumberedServerVector starts at 1 for Lua
- return servers.at(poss.begin()->second - 1).second;
+ return poss.begin()->second;
}
// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
-shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
{
(void)dq;
using LeastOutstandingType = std::tuple<int,int,double>;
if (servers.size() == 1 && servers[0].second->isUp()) {
- return servers[0].second;
+ return 1;
}
if (servers.size() <= s_staticArrayCutOff) {
return getLeastOutstanding(servers, poss);
}
-shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
{
for (auto& d : servers) {
if (d.second->isUp() && (!d.second->d_qpsLimiter || d.second->d_qpsLimiter->checkOnly())) {
- return d.second;
+ return d.first;
}
}
return leastOutstanding(servers, dq);
}
-template <class T> static std::shared_ptr<DownstreamState> getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad)
+template <class T> static std::optional<ServerPolicy::SelectedServerPosition> getValRandom(const ServerPolicy::NumberedServerVector& servers, T& poss, const unsigned int val, const double targetLoad)
{
constexpr int max = std::numeric_limits<int>::max();
int sum = 0;
// Catch the case where usableServers or sum are equal to 0 to avoid a SIGFPE
if (usableServers == 0 || sum == 0) {
- return shared_ptr<DownstreamState>();
+ return std::nullopt;
}
int r = val % sum;
auto p = std::upper_bound(poss.begin(), poss.begin() + usableServers, r, [](int r_, const typename T::value_type& a) { return r_ < a.first;});
if (p == poss.begin() + usableServers) {
- return shared_ptr<DownstreamState>();
+ return std::nullopt;
}
- // minus 1 because the NumberedServerVector starts at 1 for Lua
- return servers.at(p->second - 1).second;
+ return p->second;
}
-static shared_ptr<DownstreamState> valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers)
+static std::optional<ServerPolicy::SelectedServerPosition> valrandom(const unsigned int val, const ServerPolicy::NumberedServerVector& servers)
{
using ValRandomType = int;
double targetLoad = std::numeric_limits<double>::max();
return getValRandom(servers, poss, val, targetLoad);
}
-shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
{
(void)dq;
return valrandom(dns_random_uint32(), servers);
}
-shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash)
+std::optional<ServerPolicy::SelectedServerPosition> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash)
{
return valrandom(hash, servers);
}
-shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
{
const auto hashPerturbation = dnsdist::configuration::getImmutableConfiguration().d_hashPerturbation;
return whashedFromHash(servers, dq->ids.qname.hash(hashPerturbation));
}
-shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash)
+std::optional<ServerPolicy::SelectedServerPosition> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash)
{
unsigned int sel = std::numeric_limits<unsigned int>::max();
unsigned int min = std::numeric_limits<unsigned int>::max();
- shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
+ std::optional<ServerPolicy::SelectedServerPosition> ret, first;
double targetLoad = std::numeric_limits<double>::max();
const auto consistentHashBalancingFactor = dnsdist::configuration::getImmutableConfiguration().d_consistentHashBalancingFactor;
d.second->hash();
}
{
+ const auto position = d.first;
const auto& server = d.second;
auto hashes = server->hashes.read_lock();
// we want to keep track of the last hash
if (min > *(hashes->begin())) {
min = *(hashes->begin());
- first = server;
+ first = position;
}
auto hash_it = std::lower_bound(hashes->begin(), hashes->end(), qhash);
if (hash_it != hashes->end()) {
if (*hash_it < sel) {
sel = *hash_it;
- ret = server;
+ ret = position;
}
}
}
}
}
- if (ret != nullptr) {
+ if (ret) {
return ret;
}
- if (first != nullptr) {
+ if (first) {
return first;
}
- return shared_ptr<DownstreamState>();
+ return std::nullopt;
}
-shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
{
const auto hashPerturbation = dnsdist::configuration::getImmutableConfiguration().d_hashPerturbation;
return chashedFromHash(servers, dq->ids.qname.hash(hashPerturbation));
}
-shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
+std::optional<ServerPolicy::SelectedServerPosition> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq)
{
(void)dq;
if (servers.empty()) {
- return shared_ptr<DownstreamState>();
+ return std::nullopt;
}
vector<size_t> candidates;
if (candidates.empty()) {
if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_roundrobinFailOnNoServer) {
- return shared_ptr<DownstreamState>();
+ return std::nullopt;
}
for (auto& d : servers) {
candidates.push_back(d.first);
}
static unsigned int counter;
- return servers.at(candidates.at((counter++) % candidates.size()) - 1).second;
+ return candidates.at((counter++) % candidates.size());
}
-shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq)
+std::optional<ServerPolicy::SelectedServerPosition> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq)
{
if (servers.empty()) {
- return {};
+ return std::nullopt;
}
ServerPolicy::NumberedServerVector candidates;
candidates.reserve(servers.size());
+ std::vector<ServerPolicy::SelectedServerPosition> positionsMap;
+ positionsMap.reserve(servers.size());
int curOrder = std::numeric_limits<int>::max();
unsigned int curNumber = 1;
}
curOrder = svr.second->d_config.order;
candidates.push_back(ServerPolicy::NumberedServer(curNumber++, svr.second));
+ positionsMap.push_back(svr.first);
}
}
if (candidates.empty()) {
- return {};
+ return std::nullopt;
}
- return wrandom(candidates, dnsq);
+ auto selected = wrandom(candidates, dnsq);
+ if (selected) {
+ return positionsMap.at(*selected - 1);
+ }
+ return selected;
}
const ServerPolicy::NumberedServerVector& getDownstreamCandidates(const std::string& poolName)
return state->d_policies.at(d_name);
}
-std::shared_ptr<DownstreamState> ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const
+ServerPolicy::SelectedBackend ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const
{
- std::shared_ptr<DownstreamState> selectedBackend{nullptr};
+ ServerPolicy::SelectedBackend result{servers};
if (d_isLua) {
if (!d_isFFI) {
- auto lock = g_lua.lock();
- selectedBackend = d_policy(servers, &dq);
- }
- else {
- dnsdist_ffi_dnsquestion_t dnsq(&dq);
- dnsdist_ffi_servers_list_t serversList(servers);
- unsigned int selected = 0;
-
- if (!d_isPerThread) {
+ std::optional<SelectedServerPosition> position;
+ {
auto lock = g_lua.lock();
- selected = d_ffipolicy(&serversList, &dnsq);
+ position = d_policy(servers, &dq);
}
- else {
- const auto& policy = getPerThreadPolicy();
- selected = policy(&serversList, &dnsq);
+ if (position && *position > 0 && *position <= servers.size()) {
+ result.setSelected(*position - 1);
}
+ return result;
+ }
- if (selected >= servers.size()) {
- /* invalid offset, meaning that there is no server available */
- return {};
- }
+ dnsdist_ffi_dnsquestion_t dnsq(&dq);
+ dnsdist_ffi_servers_list_t serversList(servers);
+ ServerPolicy::SelectedServerPosition selected = 0;
- selectedBackend = servers.at(selected).second;
+ if (!d_isPerThread) {
+ auto lock = g_lua.lock();
+ selected = d_ffipolicy(&serversList, &dnsq);
+ }
+ else {
+ const auto& policy = getPerThreadPolicy();
+ selected = policy(&serversList, &dnsq);
+ }
+
+ if (selected >= servers.size()) {
+ /* invalid offset, meaning that there is no server available */
+ return result;
}
+
+ result.setSelected(selected);
+ return result;
}
- else {
- selectedBackend = d_policy(servers, &dq);
+
+ auto position = d_policy(servers, &dq);
+ if (position && *position > 0 && *position <= servers.size()) {
+ result.setSelected(*position - 1);
}
- return selectedBackend;
+ return result;
}
namespace dnsdist::lbpolicies
*/
#pragma once
+#include <memory>
+#include <optional>
+
struct dnsdist_ffi_servers_list_t;
struct dnsdist_ffi_server_t;
struct dnsdist_ffi_dnsquestion_t;
class ServerPolicy
{
public:
+ using SelectedServerPosition = unsigned int;
template <class T>
using Numbered = std::pair<unsigned int, T>;
- using NumberedServer = Numbered<shared_ptr<DownstreamState>>;
+ using NumberedServer = Numbered<std::shared_ptr<DownstreamState>>;
template <class T>
using NumberedVector = std::vector<std::pair<unsigned int, T>>;
- using NumberedServerVector = NumberedVector<shared_ptr<DownstreamState>>;
- using policyfunc_t = std::function<std::shared_ptr<DownstreamState>(const NumberedServerVector& servers, const DNSQuestion*)>;
- using ffipolicyfunc_t = std::function<unsigned int(dnsdist_ffi_servers_list_t* servers, dnsdist_ffi_dnsquestion_t* dq)>;
+ using NumberedServerVector = NumberedVector<std::shared_ptr<DownstreamState>>;
+ using policyfunc_t = std::function<std::optional<SelectedServerPosition>(const NumberedServerVector& servers, const DNSQuestion*)>;
+ using ffipolicyfunc_t = std::function<SelectedServerPosition(dnsdist_ffi_servers_list_t* servers, dnsdist_ffi_dnsquestion_t* dq)>;
ServerPolicy(const std::string& name_, policyfunc_t policy_, bool isLua_) :
d_name(name_), d_policy(std::move(policy_)), d_isLua(isLua_)
{
}
- std::shared_ptr<DownstreamState> getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const;
+ class SelectedBackend
+ {
+ public:
+ SelectedBackend(const NumberedServerVector& backends) :
+ d_backends(&backends)
+ {
+ }
+
+ void setSelected(SelectedServerPosition selected)
+ {
+ if (selected >= d_backends->size()) {
+ throw std::runtime_error("Setting an invalid backend position (" + std::to_string(selected) + " out of " + std::to_string(d_backends->size()) + ") from the server policy");
+ }
+ d_selected = selected;
+ }
+
+ operator bool() const noexcept
+ {
+ return d_selected.has_value();
+ }
+
+ DownstreamState* operator->() const noexcept
+ {
+ return (*d_backends)[*d_selected].second.get();
+ }
+
+ const std::shared_ptr<DownstreamState>& get() const noexcept
+ {
+ return (*d_backends)[*d_selected].second;
+ }
+
+ private:
+ const NumberedServerVector* d_backends{nullptr};
+ std::optional<SelectedServerPosition> d_selected;
+ };
+
+ SelectedBackend getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const;
const std::string& getName() const
{
std::string toString() const
{
- return string("ServerPolicy") + (d_isLua ? " (Lua)" : "") + " \"" + d_name + "\"";
+ return std::string("ServerPolicy") + (d_isLua ? " (Lua)" : "") + " \"" + d_name + "\"";
}
private:
using pools_t = std::map<std::string, std::shared_ptr<ServerPool>>;
const ServerPool& getPool(const std::string& poolName);
-const ServerPool& createPoolIfNotExists(const string& poolName);
-void setPoolPolicy(const string& poolName, std::shared_ptr<ServerPolicy> policy);
-void addServerToPool(const string& poolName, std::shared_ptr<DownstreamState> server);
-void removeServerFromPool(const string& poolName, std::shared_ptr<DownstreamState> server);
+const ServerPool& createPoolIfNotExists(const std::string& poolName);
+void setPoolPolicy(const std::string& poolName, std::shared_ptr<ServerPolicy> policy);
+void addServerToPool(const std::string& poolName, std::shared_ptr<DownstreamState> server);
+void removeServerFromPool(const std::string& poolName, std::shared_ptr<DownstreamState> server);
const ServerPolicy::NumberedServerVector& getDownstreamCandidates(const std::string& poolName);
-std::shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-
-std::shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
-std::shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
-std::shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
-std::shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq);
+std::optional<ServerPolicy::SelectedServerPosition> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::optional<ServerPolicy::SelectedServerPosition> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash);
+std::optional<ServerPolicy::SelectedServerPosition> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq);
+std::optional<ServerPolicy::SelectedServerPosition> orderedWrandUntag(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dnsq);
#include <unordered_map>