From: Remi Gacogne Date: Sun, 28 Mar 2021 15:09:00 +0000 (+0100) Subject: dnsdist: Allow randomly selecting a backend socket when several are available X-Git-Tag: rec-4.7.0-alpha1~11^2~6 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f573af851a4baa2339dd7cbe1c4e0e77d6ad55ec;p=thirdparty%2Fpdns.git dnsdist: Allow randomly selecting a backend socket when several are available --- diff --git a/pdns/dns_random.cc b/pdns/dns_random.cc index b6c70a46f0..7f75b4908c 100644 --- a/pdns/dns_random.cc +++ b/pdns/dns_random.cc @@ -209,32 +209,14 @@ void dns_random_init(const string& data __attribute__((unused)), bool force) { #endif } -/* Parts of this code come from arc4random_uniform */ uint32_t dns_random(uint32_t upper_bound) { if (chosen_rng == RNG_UNINITIALIZED) dns_random_setup(); - unsigned int min; if (upper_bound < 2) return 0; - /* To avoid "modulo bias" for some methods, calculate - minimum acceptable value for random number to improve - uniformity. - On applicable rngs, we loop until the rng spews out - value larger than min, and then take modulo out of that. - */ -#if (ULONG_MAX > 0xffffffffUL) - min = 0x100000000UL % upper_bound; -#else - /* Calculate (2**32 % upper_bound) avoiding 64-bit math */ - if (upper_bound > 0x80000000) - min = 1 + ~upper_bound; /* 2**32 - upper_bound */ - else { - /* (2**32 - (x * 2)) % x == 2**32 % x when x <= 2**31 */ - min = ((0xffffffff - (upper_bound * 2)) + 1) % upper_bound; - } -#endif + unsigned int min = pdns::random_minimum_acceptable_value(upper_bound); switch(chosen_rng) { case RNG_UNINITIALIZED: diff --git a/pdns/dns_random.hh b/pdns/dns_random.hh index 52f2812314..5c5e441555 100644 --- a/pdns/dns_random.hh +++ b/pdns/dns_random.hh @@ -22,6 +22,7 @@ #pragma once #include #include +#include void dns_random_init(const std::string& data = "", bool force_reinit = false); uint32_t dns_random(uint32_t n); @@ -47,4 +48,30 @@ namespace pdns { return dns_random(std::numeric_limits::max()); } }; + + /* minimum value that a PRNG should return for this upper bound to avoid a modulo bias */ + inline unsigned int random_minimum_acceptable_value(uint32_t upper_bound) + { + /* Parts of this code come from arc4random_uniform */ + /* To avoid "modulo bias" for some methods, calculate + minimum acceptable value for random number to improve + uniformity. + + On applicable rngs, we loop until the rng spews out + value larger than min, and then take modulo out of that. + */ + unsigned int min; +#if (ULONG_MAX > 0xffffffffUL) + min = 0x100000000UL % upper_bound; +#else + /* Calculate (2**32 % upper_bound) avoiding 64-bit math */ + if (upper_bound > 0x80000000) + min = 1 + ~upper_bound; /* 2**32 - upper_bound */ + else { + /* (2**32 - (x * 2)) % x == 2**32 % x when x <= 2**31 */ + min = ((0xffffffff - (upper_bound * 2)) + 1) % upper_bound; + } +#endif + return min; + } } diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index fd0b6a667e..0bbf87560d 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -99,7 +99,7 @@ struct IDState sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0; } IDState(const IDState& orig) = delete; IDState(IDState&& rhs) : - subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), packetCache(std::move(rhs.packetCache)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), cacheKeyUDP(rhs.cacheKeyUDP), origFD(rhs.origFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope) + subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), packetCache(std::move(rhs.packetCache)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), cacheKeyUDP(rhs.cacheKeyUDP), origFD(rhs.origFD), backendFD(rhs.backendFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope) { if (rhs.isInUse()) { throw std::runtime_error("Trying to move an in-use IDState"); @@ -140,6 +140,7 @@ struct IDState cacheKeyNoECS = rhs.cacheKeyNoECS; cacheKeyUDP = rhs.cacheKeyUDP; origFD = rhs.origFD; + backendFD = rhs.backendFD; delayMsec = rhs.delayMsec; #ifdef __SANITIZE_THREAD__ age.store(rhs.age.load()); @@ -249,6 +250,7 @@ struct IDState // DoH-only */ uint32_t cacheKeyUDP{0}; // 4 int origFD{-1}; // 4 + int backendFD{-1}; // 4 int delayMsec{0}; #ifdef __SANITIZE_THREAD__ std::atomic age{0}; diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index a1965531de..367de3b6dc 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -2779,6 +2779,10 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) g_socketUDPRecvBuffer = recv; }); + luaCtx.writeFunction("setRandomizedOutgoingSockets", [](bool randomized) { + DownstreamState::s_randomizeSockets = randomized; + }); + #if defined(HAVE_LIBSSL) luaCtx.writeFunction("loadTLSEngine", [client](const std::string& engineName, boost::optional defaultString) { if (client) { diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 3260371e8f..aa8d52d9a5 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -54,6 +54,7 @@ #include "dnsdist-lua.hh" #include "dnsdist-nghttp2.hh" #include "dnsdist-proxy-protocol.hh" +#include "dnsdist-random.hh" #include "dnsdist-rings.hh" #include "dnsdist-secpoll.hh" #include "dnsdist-tcp.hh" @@ -537,23 +538,6 @@ static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int return true; } -int pickBackendSocketForSending(std::shared_ptr& state) -{ - return state->sockets[state->socketsOffset++ % state->sockets.size()]; -} - -static void pickBackendSocketsReadyForReceiving(const std::shared_ptr& state, std::vector& ready) -{ - ready.clear(); - - if (state->sockets.size() == 1) { - ready.push_back(state->sockets[0]); - return ; - } - - (*state->mplexer.lock())->getAvailableFDs(ready, 1000); -} - void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol) { struct timespec ts; @@ -592,7 +576,7 @@ void responderThread(std::shared_ptr dss) for(;;) { try { - pickBackendSocketsReadyForReceiving(dss, sockets); + dss->pickSocketsReadyForReceiving(sockets); if (dss->isStopped()) { break; } @@ -639,7 +623,7 @@ void responderThread(std::shared_ptr dss) int origFD = ids->origFD; unsigned int qnameWireLength = 0; - if (!responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss->remote, qnameWireLength)) { + if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss->remote, qnameWireLength)) { continue; } @@ -1587,7 +1571,8 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct addProxyProtocol(dq); } - int fd = pickBackendSocketForSending(ss); + int fd = ss->pickSocketForSending(); + ids->backendFD = fd; ssize_t ret = udpClientSendRequestToBackend(ss, fd, query); if(ret < 0) { @@ -1755,16 +1740,6 @@ static void udpClientThread(ClientState* cs) } } - -uint16_t getRandomDNSID() -{ -#ifdef HAVE_LIBSODIUM - return randombytes_uniform(65536); -#else - return (random() % 65536); -#endif -} - boost::optional g_maxTCPClientThreads{boost::none}; pdns::stat16_t g_cacheCleaningDelay{60}; pdns::stat16_t g_cacheCleaningPercentage{100}; @@ -2269,17 +2244,10 @@ int main(int argc, char** argv) cerr<<"Unable to initialize crypto library"<(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t; + typedef std::function(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t; DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf, const std::string& sourceItfName); DownstreamState(const ComboAddress& remote_): DownstreamState(remote_, ComboAddress(), 0, std::string()) {} @@ -880,6 +880,9 @@ public: } bool passCrossProtocolQuery(std::unique_ptr&& cpq); + int pickSocketForSending(); + void pickSocketsReadyForReceiving(std::vector& ready); + dnsdist::Protocol getProtocol() const { if (isDoH()) { @@ -893,6 +896,8 @@ public: } return dnsdist::Protocol::DoUDP; } + + static bool s_randomizeSockets; }; using servers_t =vector>; @@ -1054,8 +1059,6 @@ extern std::vector> g_dnsCryptLocals; int handleDNSCryptQuery(PacketBuffer& packet, DNSCryptQuery& query, bool tcp, time_t now, PacketBuffer& response); bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::unique_ptr& dnsCryptQuery, time_t now, bool tcp); -uint16_t getRandomDNSID(); - #include "dnsdist-snmp.hh" extern bool g_snmpEnabled; @@ -1073,7 +1076,6 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& DNSResponse makeDNSResponseFromIDState(IDState& ids, PacketBuffer& data); void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname); -int pickBackendSocketForSending(std::shared_ptr& state); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index fb09c81869..fa5d8aa5b1 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -132,6 +132,7 @@ dnsdist_SOURCES = \ connection-management.hh \ credentials.cc credentials.hh \ dns.cc dns.hh \ + dns_random.hh \ dnscrypt.cc dnscrypt.hh \ dnsdist-backend.cc \ dnsdist-cache.cc dnsdist-cache.hh \ @@ -165,6 +166,7 @@ dnsdist_SOURCES = \ dnsdist-protobuf.cc dnsdist-protobuf.hh \ dnsdist-protocols.cc dnsdist-protocols.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ + dnsdist-random.cc dnsdist-random.hh \ dnsdist-rings.cc dnsdist-rings.hh \ dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-secpoll.cc dnsdist-secpoll.hh \ @@ -248,6 +250,7 @@ testrunner_SOURCES = \ dnsdist-nghttp2.cc dnsdist-nghttp2.hh \ dnsdist-protocols.cc dnsdist-protocols.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ + dnsdist-random.cc dnsdist-random.hh \ dnsdist-rings.cc dnsdist-rings.hh \ dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-session-cache.cc dnsdist-session-cache.hh \ diff --git a/pdns/dnsdistdist/dns_random.hh b/pdns/dnsdistdist/dns_random.hh new file mode 120000 index 0000000000..cb207715c8 --- /dev/null +++ b/pdns/dnsdistdist/dns_random.hh @@ -0,0 +1 @@ +../dns_random.hh \ No newline at end of file diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index 2af756f8cc..cd8c2be217 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -22,6 +22,7 @@ #include "dnsdist.hh" #include "dnsdist-nghttp2.hh" +#include "dnsdist-random.hh" #include "dnsdist-tcp.hh" #include "dolog.hh" @@ -212,6 +213,38 @@ void DownstreamState::incCurrentConnectionsCount() } } +int DownstreamState::pickSocketForSending() +{ + size_t numberOfSockets = sockets.size(); + if (numberOfSockets == 1) { + return sockets[0]; + } + + size_t idx; + if (s_randomizeSockets) { + idx = dnsdist::getRandomValue(numberOfSockets); + } + else { + idx = socketsOffset++; + } + + return sockets[idx % numberOfSockets]; +} + +void DownstreamState::pickSocketsReadyForReceiving(std::vector& ready) +{ + ready.clear(); + + if (sockets.size() == 1) { + ready.push_back(sockets[0]); + return ; + } + + (*mplexer.lock())->getAvailableFDs(ready, 1000); +} + +bool DownstreamState::s_randomizeSockets{false}; + size_t ServerPool::countServers(bool upOnly) { size_t count = 0; diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index 50b9341ccc..a8c82d85a4 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -24,6 +24,7 @@ #include "tcpiohandler-mplexer.hh" #include "dnswriter.hh" #include "dolog.hh" +#include "dnsdist-random.hh" #include "dnsdist-tcp.hh" #include "dnsdist-nghttp2.hh" #include "dnsdist-session-cache.hh" @@ -325,7 +326,7 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared { try { - uint16_t queryID = getRandomDNSID(); + uint16_t queryID = dnsdist::getRandomDNSID(); DNSName checkName = ds->checkName; uint16_t checkType = ds->checkType.getCode(); uint16_t checkClass = ds->checkClass; diff --git a/pdns/dnsdistdist/dnsdist-random.cc b/pdns/dnsdistdist/dnsdist-random.cc new file mode 100644 index 0000000000..dc8a441f0d --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-random.cc @@ -0,0 +1,90 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "config.h" + +#include +#include +#include +#ifdef HAVE_LIBSODIUM +#include +#endif /* HAVE_LIBSODIUM */ +#ifdef HAVE_RAND_BYTES +#include +#endif /* HAVE_RAND_BYTES */ + +#include "dnsdist-random.hh" +#include "dns_random.hh" + +namespace dnsdist +{ +void initRandom() +{ +#ifdef HAVE_LIBSODIUM + srandom(randombytes_uniform(0xffffffff)); +#else + { + auto getSeed = []() { +#ifdef HAVE_RAND_BYTES + unsigned int seed; + if (RAND_bytes(reinterpret_cast(&seed), sizeof(seed)) == 1) { + return seed; + } +#endif /* HAVE_RAND_BYTES */ + struct timeval tv; + gettimeofday(&tv, 0); + return static_cast(tv.tv_sec ^ tv.tv_usec ^ getpid()); + }; + + srandom(getSeed()); + } +#endif +} + +uint32_t getRandomValue(uint32_t upperBound) +{ +#ifdef HAVE_LIBSODIUM + return randombytes_uniform(upperBound); +#else /* HAVE_LIBSODIUM */ + uint32_t result; + unsigned int min = pdns::random_minimum_acceptable_value(upperBound); +#ifdef HAVE_RAND_BYTES + do { + if (RAND_bytes(reinterpret_cast(&result), sizeof(result)) != 1) { + throw std::runtime_error("Error getting a random value via RAND_bytes"); + } + } while (result < min); + + return result % upperBound; +#endif /* HAVE_RAND_BYTES */ + do { + result = random(); + } while (result < min); + + return result % upperBound; +#endif /* HAVE_LIBSODIUM */ +} + +uint16_t getRandomDNSID() +{ + return getRandomValue(65536); +} +} diff --git a/pdns/dnsdistdist/dnsdist-random.hh b/pdns/dnsdistdist/dnsdist-random.hh new file mode 100644 index 0000000000..d7f0f031b9 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-random.hh @@ -0,0 +1,31 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include + +namespace dnsdist +{ +void initRandom(); +uint32_t getRandomValue(uint32_t upperBound); +uint16_t getRandomDNSID(); +} diff --git a/pdns/dnsdistdist/dnsdist-secpoll.cc b/pdns/dnsdistdist/dnsdist-secpoll.cc index 008302d111..4f5a9994bc 100644 --- a/pdns/dnsdistdist/dnsdist-secpoll.cc +++ b/pdns/dnsdistdist/dnsdist-secpoll.cc @@ -36,6 +36,7 @@ #include "sstuff.hh" #include "dnsdist.hh" +#include "dnsdist-random.hh" #ifndef PACKAGEVERSION #define PACKAGEVERSION PACKAGE_VERSION @@ -92,7 +93,7 @@ static std::string getSecPollStatus(const std::string& queriedName, int timeout= const DNSName& sentName = DNSName(queriedName); std::vector packet; DNSPacketWriter pw(packet, sentName, QType::TXT); - pw.getHeader()->id = getRandomDNSID(); + pw.getHeader()->id = dnsdist::getRandomDNSID(); pw.getHeader()->rd = 1; const auto& resolversForStub = getResolvers("/etc/resolv.conf"); diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index c6e3100d64..e06d64737f 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -724,7 +724,8 @@ static void processDOHQuery(DOHUnitUniquePtr&& du) } } - int fd = pickBackendSocketForSending(du->downstream); + int fd = du->downstream->pickSocketForSending(); + ids->backendFD = fd; try { /* you can't touch du after this line, unless the call returned a non-negative value, because it might already have been freed */