From: Remi Gacogne Date: Wed, 31 Mar 2021 15:22:21 +0000 (+0200) Subject: dnsdist: First working version of cross-protocol DoH -> TCP X-Git-Tag: dnsdist-1.7.0-alpha1~45^2~46 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2c0e81bbf0641ad0a9dc2228295c05250530c779;p=thirdparty%2Fpdns.git dnsdist: First working version of cross-protocol DoH -> TCP --- diff --git a/pdns/Makefile.am b/pdns/Makefile.am index 137de97c83..a17aa95e0a 100644 --- a/pdns/Makefile.am +++ b/pdns/Makefile.am @@ -1562,6 +1562,8 @@ fuzz_target_dnsdistcache_SOURCES = \ dns.cc dns.hh \ dnsdist-cache.cc dnsdist-cache.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ + dnsdist-idstate.hh \ + dnsdist-protocols.hh \ dnslabeltext.cc \ dnsname.cc dnsname.hh \ dnsparser.cc dnsparser.hh \ diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index 2ab26821c5..4a5599519d 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -622,7 +622,6 @@ const std::vector g_consoleKeywords{ { "setSyslogFacility", true, "facility", "set the syslog logging facility to 'facility'. Defaults to LOG_DAEMON" }, { "setTCPDownstreamCleanupInterval", true, "interval", "minimum interval in seconds between two cleanups of the idle TCP downstream connections" }, { "setTCPInternalPipeBufferSize", true, "size", "Set the size in bytes of the internal buffer of the pipes used internally to distribute connections to TCP (and DoT) workers threads" }, - { "setTCPUseSinglePipe", true, "bool", "whether the incoming TCP connections should be put into a single queue instead of using per-thread queues. Defaults to false" }, { "setTCPRecvTimeout", true, "n", "set the read timeout on TCP connections from the client, in seconds" }, { "setTCPSendTimeout", true, "n", "set the write timeout on TCP connections from the client, in seconds" }, { "setUDPMultipleMessagesVectorSize", true, "n", "set the size of the vector passed to recvmmsg() to receive UDP messages. 
Default to 1 which means that the feature is disabled and recvmsg() is used instead" }, diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh new file mode 100644 index 0000000000..69d882b746 --- /dev/null +++ b/pdns/dnsdist-idstate.hh @@ -0,0 +1,258 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ */ +#pragma once + +#include "config.h" +#include "dnsname.hh" +#include "dnsdist-protocols.hh" +#include "gettime.hh" +#include "iputils.hh" +#include "uuid-utils.hh" + +struct ClientState; +struct DOHUnit; +class DNSCryptQuery; +class DNSDistPacketCache; + +using QTag = std::unordered_map; + +struct StopWatch +{ + StopWatch(bool realTime=false): d_needRealTime(realTime) + { + } + + void start() { + if (gettime(&d_start, d_needRealTime) < 0) { + unixDie("Getting timestamp"); + } + } + + void set(const struct timespec& from) { + d_start = from; + } + + double udiff() const { + struct timespec now; + if (gettime(&now, d_needRealTime) < 0) { + unixDie("Getting timestamp"); + } + + return 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0; + } + + double udiffAndSet() { + struct timespec now; + if (gettime(&now, d_needRealTime) < 0) { + unixDie("Getting timestamp"); + } + + auto ret= 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0; + d_start = now; + return ret; + } + + struct timespec d_start{0,0}; +private: + bool d_needRealTime{false}; +}; + +/* g++ defines __SANITIZE_THREAD__ + clang++ supports the nice __has_feature(thread_sanitizer), + let's merge them */ +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define __SANITIZE_THREAD__ 1 +#endif +#endif + +struct IDState +{ + IDState(): sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;} + IDState(const IDState& orig) = delete; + IDState(IDState&& rhs): subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), origFD(rhs.origFD), delayMsec(rhs.delayMsec), 
qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope) + { + if (rhs.isInUse()) { + throw std::runtime_error("Trying to move an in-use IDState"); + } + + uniqueId = std::move(rhs.uniqueId); +#ifdef __SANITIZE_THREAD__ + age.store(rhs.age.load()); +#else + age = rhs.age; +#endif + } + + IDState& operator=(IDState&& rhs) + { + if (isInUse()) { + throw std::runtime_error("Trying to overwrite an in-use IDState"); + } + + if (rhs.isInUse()) { + throw std::runtime_error("Trying to move an in-use IDState"); + } + + subnet = std::move(rhs.subnet); + origRemote = rhs.origRemote; + origDest = rhs.origDest; + hopRemote = rhs.hopRemote; + hopLocal = rhs.hopLocal; + qname = std::move(rhs.qname); + sentTime = rhs.sentTime; + dnsCryptQuery = std::move(rhs.dnsCryptQuery); + packetCache = std::move(rhs.packetCache); + qTag = std::move(rhs.qTag); + tempFailureTTL = std::move(rhs.tempFailureTTL); + cs = rhs.cs; + du = std::move(rhs.du); + cacheKey = rhs.cacheKey; + cacheKeyNoECS = rhs.cacheKeyNoECS; + origFD = rhs.origFD; + delayMsec = rhs.delayMsec; +#ifdef __SANITIZE_THREAD__ + age.store(rhs.age.load()); +#else + age = rhs.age; +#endif + qtype = rhs.qtype; + qclass = rhs.qclass; + origID = rhs.origID; + origFlags = rhs.origFlags; + cacheFlags = rhs.cacheFlags; + protocol = rhs.protocol; + uniqueId = std::move(rhs.uniqueId); + ednsAdded = rhs.ednsAdded; + ecsAdded = rhs.ecsAdded; + skipCache = rhs.skipCache; + destHarvested = rhs.destHarvested; + dnssecOK = rhs.dnssecOK; + useZeroScope = rhs.useZeroScope; + + return *this; + } + + static const int64_t unusedIndicator = -1; + + static bool isInUse(int64_t usageIndicator) + { + return usageIndicator != unusedIndicator; + } + + bool isInUse() const + { + return usageIndicator != 
unusedIndicator; + } + + /* return true if the value has been successfully replaced meaning that + no-one updated the usage indicator in the meantime */ + bool tryMarkUnused(int64_t expectedUsageIndicator) + { + return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator); + } + + /* mark as used no matter what, return true if the state was in use before */ + bool markAsUsed() + { + auto currentGeneration = generation++; + return markAsUsed(currentGeneration); + } + + /* mark as used no matter what, return true if the state was in use before */ + bool markAsUsed(int64_t currentGeneration) + { + int64_t oldUsage = usageIndicator.exchange(currentGeneration); + return oldUsage != unusedIndicator; + } + + /* We use this value to detect whether this state is in use. + For performance reasons we don't want to use a lock here, but that means + we need to be very careful when modifying this value. Modifications happen + from: + - one of the UDP or DoH 'client' threads receiving a query, selecting a backend + then picking one of the states associated to this backend (via the idOffset). + Most of the time this state should not be in use and usageIndicator is -1, but we + might not yet have received a response for the query previously associated to this + state, meaning that we will 'reuse' this state and erase the existing state. + If we ever receive a response for this state, it will be discarded. This is + mostly fine for UDP except that we still need to be careful in order to miss + the 'outstanding' counters, which should only be increased when we are picking + an empty state, and not when reusing ; + For DoH, though, we have dynamically allocated a DOHUnit object that needs to + be freed, as well as internal objects internals to libh2o. 
+ - one of the UDP receiver threads receiving a response from a backend, picking + the corresponding state and sending the response to the client ; + - the 'healthcheck' thread scanning the states to actively discover timeouts, + mostly to keep some counters like the 'outstanding' one sane. + We previously based that logic on the origFD (FD on which the query was received, + and therefore from where the response should be sent) but this suffered from an + ABA problem since it was quite likely that a UDP 'client thread' would reset it to the + same value since we only have so many incoming sockets: + - 1/ 'client' thread gets a query and sets origFD to its FD, say 5 ; + - 2/ 'receiver' thread gets a response, reads the value of origFD as 5, checks that the qname, + qtype and qclass match + - 3/ during that time the 'client' thread reuses the state, setting origFD to 5 again ; + - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still + 5, except it's not the same 5 anymore and it overwrites a fresh state. + We now use a 32-bit unsigned counter instead, which is incremented every time the state is set, + wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1 + when the state is unused and the value of our counter otherwise. 
+ */ + boost::optional subnet{boost::none}; // 40 + ComboAddress origRemote; // 28 + ComboAddress origDest; // 28 + ComboAddress hopRemote; + ComboAddress hopLocal; + DNSName qname; // 24 + StopWatch sentTime; // 16 + std::shared_ptr dnsCryptQuery{nullptr}; // 16 + std::shared_ptr packetCache{nullptr}; // 16 + std::shared_ptr qTag{nullptr}; // 16 + boost::optional tempFailureTTL; // 8 + const ClientState* cs{nullptr}; // 8 + DOHUnit* du{nullptr}; // 8 + std::atomic usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8 + std::atomic generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4 + uint32_t cacheKey{0}; // 4 + uint32_t cacheKeyNoECS{0}; // 4 + int origFD{-1}; // 4 + int delayMsec{0}; +#ifdef __SANITIZE_THREAD__ + std::atomic age{0}; +#else + uint16_t age{0}; // 2 +#endif + uint16_t qtype{0}; // 2 + uint16_t qclass{0}; // 2 + uint16_t origID{0}; // 2 + uint16_t origFlags{0}; // 2 + uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2 + dnsdist::Protocol protocol; // 1 + boost::optional uniqueId{boost::none}; // 17 (placed here to reduce the space lost to padding) + bool ednsAdded{false}; + bool ecsAdded{false}; + bool skipCache{false}; + bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr + bool dnssecOK{false}; + bool useZeroScope{false}; +}; diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index b5c0703967..12df54ae93 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -1224,23 +1224,23 @@ private: bool d_hasV6; }; -static DnstapMessage::ProtocolType ProtocolToDNSTap(DNSQuestion::Protocol protocol) +static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol) { DnstapMessage::ProtocolType result; switch (protocol) { default: - case DNSQuestion::Protocol::DoUDP: - case DNSQuestion::Protocol::DNSCryptUDP: + case dnsdist::Protocol::DoUDP: + case 
dnsdist::Protocol::DNSCryptUDP: result = DnstapMessage::ProtocolType::DoUDP; break; - case DNSQuestion::Protocol::DoTCP: - case DNSQuestion::Protocol::DNSCryptTCP: + case dnsdist::Protocol::DoTCP: + case dnsdist::Protocol::DNSCryptTCP: result = DnstapMessage::ProtocolType::DoTCP; break; - case DNSQuestion::Protocol::DoT: + case dnsdist::Protocol::DoT: result = DnstapMessage::ProtocolType::DoT; break; - case DNSQuestion::Protocol::DoH: + case dnsdist::Protocol::DoH: result = DnstapMessage::ProtocolType::DoH; break; } diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index f213335e26..3ddfa44157 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -74,7 +74,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) }); luaCtx.registerFunction("getProtocol", [](const DNSQuestion& dq) { - return DNSQuestion::ProtocolToString(dq.getProtocol()); + return dnsdist::ProtocolToString(dq.getProtocol()); }); luaCtx.registerFunction("sendTrap", [](const DNSQuestion& dq, boost::optional reason) { @@ -252,7 +252,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) }); luaCtx.registerFunction("getProtocol", [](const DNSResponse& dr) { - return DNSQuestion::ProtocolToString(dr.getProtocol()); + return dnsdist::ProtocolToString(dr.getProtocol()); }); luaCtx.registerFunction("sendTrap", [](const DNSResponse& dr, boost::optional reason) { diff --git a/pdns/dnsdist-lua-inspection.cc b/pdns/dnsdist-lua-inspection.cc index b1d6a38b33..4c9024c8bb 100644 --- a/pdns/dnsdist-lua-inspection.cc +++ b/pdns/dnsdist-lua-inspection.cc @@ -23,6 +23,7 @@ #include "dnsdist-lua.hh" #include "dnsdist-dynblocks.hh" #include "dnsdist-rings.hh" +#include "dnsdist-tcp.hh" #include "statnode.hh" @@ -597,9 +598,6 @@ void setupLuaInspection(LuaContext& luaCtx) ret << (fmt % g_tcpclientthreads->getThreadsCount() % (g_maxTCPClientThreads ? 
*g_maxTCPClientThreads : 0) % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections) << endl; ret << endl; - ret << "Query distribution mode is: " << std::string(g_useTCPSinglePipe ? "single queue" : "per-thread queues") << endl; - ret << endl; - ret << "Frontends:" << endl; fmt = boost::format("%-3d %-20.20s %-20d %-20d %-20d %-25d %-20d %-20d %-20d %-20f %-20f %-20d %-20d %-25d %-25d %-15d %-15d %-15d %-15d %-15d"); ret << (fmt % "#" % "Address" % "Connections" % "Max concurrent conn" % "Died reading query" % "Died sending response" % "Gave up" % "Client timeouts" % "Downstream timeouts" % "Avg queries/conn" % "Avg duration" % "TLS new sessions" % "TLS Resumptions" % "TLS unknown ticket keys" % "TLS inactive ticket keys" % "TLS 1.0" % "TLS 1.1" % "TLS 1.2" % "TLS 1.3" % "TLS other") << endl; diff --git a/pdns/dnsdist-lua-rules.cc b/pdns/dnsdist-lua-rules.cc index 9c9bec0afa..12e44a042c 100644 --- a/pdns/dnsdist-lua-rules.cc +++ b/pdns/dnsdist-lua-rules.cc @@ -444,7 +444,7 @@ void setupLuaRules(LuaContext& luaCtx) sw.start(); for(int n=0; n < times; ++n) { item& i = items[n % items.size()]; - DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, DNSQuestion::Protocol::DoUDP, &sw.d_start); + DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, dnsdist::Protocol::DoUDP, &sw.d_start); if (rule->matches(&dq)) { matches++; } diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 879b785e5a..b1f0925fc7 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1843,15 +1843,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) g_hashperturb = pertub; }); - luaCtx.writeFunction("setTCPUseSinglePipe", [](bool flag) { - if (g_configurationDone) { - g_outputBuffer="setTCPUseSinglePipe() cannot be used at runtime!\n"; - return; - } - setLuaSideEffect(); - g_useTCPSinglePipe = flag; - }); - luaCtx.writeFunction("setTCPInternalPipeBufferSize", [](size_t size) { g_tcpInternalPipeBufferSize = 
size; }); luaCtx.writeFunction("snmpAgent", [client,configCheck](bool enableTraps, boost::optional daemonSocket) { diff --git a/pdns/dnsdist-protobuf.cc b/pdns/dnsdist-protobuf.cc index 8e91640096..e2e6c04106 100644 --- a/pdns/dnsdist-protobuf.cc +++ b/pdns/dnsdist-protobuf.cc @@ -124,7 +124,7 @@ void DNSDistProtoBufMessage::serialize(std::string& data) const m.setTime(ts.tv_sec, ts.tv_nsec / 1000); } - m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.getHeader()->id, (d_dq.getProtocol() == DNSQuestion::Protocol::DoH) ? true : d_dq.overTCP(), d_bytes ? *d_bytes : d_dq.getData().size()); + m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.getHeader()->id, (d_dq.getProtocol() == dnsdist::Protocol::DoH) ? true : d_dq.overTCP(), d_bytes ? *d_bytes : d_dq.getData().size()); if (d_serverIdentity) { m.setServerIdentity(*d_serverIdentity); diff --git a/pdns/dnsdist-protocols.hh b/pdns/dnsdist-protocols.hh new file mode 100644 index 0000000000..271c24178b --- /dev/null +++ b/pdns/dnsdist-protocols.hh @@ -0,0 +1,31 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. 
+ * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include +#include + +namespace dnsdist { + enum class Protocol : uint8_t { DoUDP, DoTCP, DNSCryptUDP, DNSCryptTCP, DoT, DoH }; + + const std::string& ProtocolToString(Protocol proto); +} diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 0a7da95372..cea13fc631 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -28,6 +28,7 @@ #include "dnsdist-ecs.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" +#include "dnsdist-tcp.hh" #include "dnsdist-tcp-downstream.hh" #include "dnsdist-tcp-upstream.hh" #include "dnsdist-xpf.hh" @@ -72,7 +73,6 @@ uint64_t g_maxTCPQueuedConnections{1000}; uint16_t g_downstreamTCPCleanupInterval{60}; int g_tcpRecvTimeout{2}; int g_tcpSendTimeout{2}; -bool g_useTCPSinglePipe{false}; std::atomic g_tcpStatesDumpRequested{0}; class DownstreamConnectionsManager @@ -110,7 +110,7 @@ public: } } - return std::make_shared(ds, now); + return std::make_shared(ds, mplexer, now); } static void releaseDownstreamConnection(std::shared_ptr&& conn) @@ -248,100 +248,80 @@ std::shared_ptr IncomingTCPConnectionState::getDownstrea return downstream; } -static void tcpClientThread(int pipefd); +static void tcpClientThread(int pipefd, int crossProtocolPipeFD); -TCPClientCollection::TCPClientCollection(size_t maxThreads, bool useSinglePipe): 
d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads), d_singlePipe{-1,-1}, d_useSinglePipe(useSinglePipe) +TCPClientCollection::TCPClientCollection(size_t maxThreads): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) { - if (d_useSinglePipe) { - if (pipe(d_singlePipe) < 0) { - int err = errno; - throw std::runtime_error("Error creating the TCP single communication pipe: " + stringerror(err)); - } - - if (!setNonBlocking(d_singlePipe[0])) { - int err = errno; - close(d_singlePipe[0]); - close(d_singlePipe[1]); - throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + stringerror(err)); - } - - if (!setNonBlocking(d_singlePipe[1])) { - int err = errno; - close(d_singlePipe[0]); - close(d_singlePipe[1]); - throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + stringerror(err)); - } - - if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(d_singlePipe[0]) < g_tcpInternalPipeBufferSize) { - setPipeBufferSize(d_singlePipe[0], g_tcpInternalPipeBufferSize); - } - } } void TCPClientCollection::addTCPClientThread() { - int pipefds[2] = { -1, -1}; - - vinfolog("Adding TCP Client thread"); - - if (d_useSinglePipe) { - pipefds[0] = d_singlePipe[0]; - pipefds[1] = d_singlePipe[1]; - } - else { - if (pipe(pipefds) < 0) { - errlog("Error creating the TCP thread communication pipe: %s", stringerror()); - return; + auto preparePipe = [](int fds[2], const std::string& type) -> bool { + if (pipe(fds) < 0) { + errlog("Error creating the TCP thread %s pipe: %s", type, stringerror()); + return false; } - if (!setNonBlocking(pipefds[0])) { + if (!setNonBlocking(fds[0])) { int err = errno; - close(pipefds[0]); - close(pipefds[1]); - errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err)); - return; + close(fds[0]); + close(fds[1]); + errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err)); + return false; } - if 
(!setNonBlocking(pipefds[1])) { + if (!setNonBlocking(fds[1])) { int err = errno; - close(pipefds[0]); - close(pipefds[1]); - errlog("Error setting the TCP thread communication pipe non-blocking: %s", stringerror(err)); - return; + close(fds[0]); + close(fds[1]); + errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err)); + return false; } - if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(pipefds[0]) < g_tcpInternalPipeBufferSize) { - setPipeBufferSize(pipefds[0], g_tcpInternalPipeBufferSize); + if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) { + setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize); } + + return true; + }; + + int pipefds[2] = { -1, -1}; + if (!preparePipe(pipefds, "communication")) { + return; } + int crossProtocolFDs[2] = { -1, -1}; + if (!preparePipe(crossProtocolFDs, "cross-protocol")) { + return; + } + + vinfolog("Adding TCP Client thread"); + { std::lock_guard lock(d_mutex); if (d_numthreads >= d_tcpclientthreads.size()) { vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. 
Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size()); - if (!d_useSinglePipe) { - close(pipefds[0]); - close(pipefds[1]); - } + close(pipefds[0]); + close(pipefds[1]); return; } + /* from now on this side of the pipe will be managed by that object, + no need to worry about it */ + TCPWorkerThread worker(pipefds[1], crossProtocolFDs[1]); try { - std::thread t1(tcpClientThread, pipefds[0]); + std::thread t1(tcpClientThread, pipefds[0], crossProtocolFDs[0]); t1.detach(); } catch (const std::runtime_error& e) { /* the thread creation failed, don't leak */ errlog("Error creating a TCP thread: %s", e.what()); - if (!d_useSinglePipe) { - close(pipefds[0]); - close(pipefds[1]); - } + close(pipefds[0]); return; } - d_tcpclientthreads.at(d_numthreads) = pipefds[1]; + d_tcpclientthreads.at(d_numthreads) = std::move(worker); ++d_numthreads; } } @@ -369,7 +349,7 @@ static IOState sendQueuedResponses(std::shared_ptr& static void handleResponseSent(std::shared_ptr& state, const TCPResponse& currentResponse) { - if (state->d_isXFR || currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) { + if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) { return; } @@ -399,6 +379,16 @@ static void handleResponseSent(std::shared_ptr& stat } } +static void prependSizeToTCPQuery(PacketBuffer& buffer) +{ + uint16_t queryLen = buffer.size(); + const uint8_t sizeBytes[] = { static_cast(queryLen / 256), static_cast(queryLen % 256) }; + /* prepend the size. 
Yes, this is not the most efficient way but it prevents mistakes + that could occur if we had to deal with the size during the processing, + especially alignment issues */ + buffer.insert(buffer.begin(), sizeBytes, sizeBytes + 2); +} + bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) { if (d_hadErrors) { @@ -406,11 +396,6 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) return false; } - if (d_isXFR) { - DEBUGLOG("not accepting new queries because used for XFR"); - return false; - } - if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { DEBUGLOG("not accepting new queries because we already have "<d_maxInFlightQueriesPerConn); return false; @@ -434,9 +419,6 @@ void IncomingTCPConnectionState::resetForNewQuery() d_buffer.resize(sizeof(uint16_t)); d_currentPos = 0; d_querySize = 0; - d_xfrMasterSerial = 0; - d_xfrSerialCount = 0; - d_xfrMasterSerialCount = 0; d_state = State::waitingForQuery; } @@ -548,8 +530,10 @@ void IncomingTCPConnectionState::queueResponse(std::shared_ptr state, const struct timeval& now, TCPResponse&& response) +void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response) { + std::shared_ptr state = shared_from_this(); + if (response.d_connection && response.d_connection->isIdle()) { // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool yet, no one else will be able to use it anyway if (response.d_connection->canBeReused()) { @@ -680,12 +664,12 @@ static void handleQuery(std::shared_ptr& state, cons uint16_t qtype, qclass; unsigned int qnameWireLength = 0; DNSName qname(reinterpret_cast(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); - DNSQuestion::Protocol protocol = DNSQuestion::Protocol::DoTCP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoTCP; if (dnsCryptQuery) { - protocol = 
DNSQuestion::Protocol::DNSCryptTCP; + protocol = dnsdist::Protocol::DNSCryptTCP; } else if (state->d_handler.isTLS()) { - protocol = DNSQuestion::Protocol::DoT; + protocol = dnsdist::Protocol::DoT; } DNSQuestion dq(&qname, qtype, qclass, &state->d_proxiedDestination, &state->d_proxiedRemote, state->d_buffer, protocol, &queryRealTime); @@ -697,8 +681,7 @@ static void handleQuery(std::shared_ptr& state, cons dq.proxyProtocolValues = make_unique>(*state->d_proxyProtocolValues); } - state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR); - if (state->d_isXFR) { + if (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR) { dq.skipCache = true; } @@ -731,15 +714,9 @@ static void handleQuery(std::shared_ptr& state, cons setIDStateFromDNSQuestion(ids, dq, std::move(qname)); ids.origID = ntohs(dh->id); - uint16_t queryLen = state->d_buffer.size(); - const uint8_t sizeBytes[] = { static_cast(queryLen / 256), static_cast(queryLen % 256) }; - /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes - that could occur if we had to deal with the size during the processing, - especially alignment issues */ - state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2); + prependSizeToTCPQuery(state->d_buffer); auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); - downstreamConnection->assignToClientConnection(state, state->d_isXFR); bool proxyProtocolPayloadAdded = false; std::string proxyProtocolPayload; @@ -772,7 +749,8 @@ static void handleQuery(std::shared_ptr& state, cons ++state->d_currentQueriesCount; vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? 
"DoT" : "TCP"), query.d_buffer.size(), ds->getName()); - downstreamConnection->queueQuery(std::move(query), downstreamConnection); + std::shared_ptr incoming = state; + downstreamConnection->queueQuery(incoming, std::move(query)); } void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) @@ -1034,8 +1012,10 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptrd_lastIOBlocked); } -void IncomingTCPConnectionState::notifyIOError(std::shared_ptr& state, IDState&& query, const struct timeval& now) +void IncomingTCPConnectionState::notifyIOError(IDState&& query, const struct timeval& now) { + std::shared_ptr state = shared_from_this(); + --state->d_currentQueriesCount; state->d_hadErrors = true; @@ -1062,8 +1042,9 @@ void IncomingTCPConnectionState::notifyIOError(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response) +void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TCPResponse&& response) { + std::shared_ptr state = shared_from_this(); queueResponse(state, now, std::move(response)); } @@ -1124,14 +1105,56 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param IncomingTCPConnectionState::handleIO(state, now); } - catch(...) { + catch (...) { delete citmp; citmp = nullptr; throw; } } -static void tcpClientThread(int pipefd) +static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) +{ + auto threadData = boost::any_cast(param); + CrossProtocolQuery* tmp{nullptr}; + + ssize_t got = read(pipefd, &tmp, sizeof(tmp)); + if (got == 0) { + throw std::runtime_error("EOF while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? 
"non-blocking" : "blocking") + " mode"); + } + else if (got == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + throw std::runtime_error("Error while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + } + else if (got != sizeof(tmp)) { + throw std::runtime_error("Partial read while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + + try { + struct timeval now; + gettimeofday(&now, nullptr); + + auto query = std::move(tmp->query); + auto downstreamServer = std::move(tmp->downstream); + std::shared_ptr tqs = tmp->getTCPQuerySender(); + delete tmp; + tmp = nullptr; + + auto downstream = DownstreamConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); + +#warning FIXME: what if a proxy protocol payload was inserted? + prependSizeToTCPQuery(query.d_buffer); + downstream->queueQuery(tqs, std::move(query)); + } + catch (...) 
{ + delete tmp; + tmp = nullptr; + throw; + } +} + +static void tcpClientThread(int pipefd, int crossProtocolPipeFD) { /* we get launched with a pipe on which we receive file descriptors from clients that we own from that point on */ @@ -1141,6 +1164,8 @@ static void tcpClientThread(int pipefd) TCPClientThreadData data; data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data); + data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data); + struct timeval now; gettimeofday(&now, nullptr); time_t lastTCPCleanup = now.tv_sec; @@ -1238,7 +1263,6 @@ void tcpAcceptorThread(ClientState* cs) auto acl = g_ACL.getLocal(); for(;;) { - bool queuedCounterIncremented = false; std::unique_ptr ci; tcpClientCountIncremented = false; try { @@ -1294,23 +1318,7 @@ void tcpAcceptorThread(ClientState* cs) vinfolog("Got TCP connection from %s", remote.toStringWithPort()); ci->remote = remote; - int pipe = g_tcpclientthreads->getThread(); - if (pipe >= 0) { - queuedCounterIncremented = true; - auto tmp = ci.release(); - try { - // throws on failure - writen2WithTimeout(pipe, &tmp, sizeof(tmp), timeval{0,0}); - } - catch (...) 
{ - delete tmp; - tmp = nullptr; - throw; - } - } - else { - g_tcpclientthreads->decrementQueuedCount(); - queuedCounterIncremented = false; + if (!g_tcpclientthreads->passConnectionToThread(std::move(ci))) { if (tcpClientCountIncremented) { decrementTCPClientCount(remote); } @@ -1321,9 +1329,6 @@ void tcpAcceptorThread(ClientState* cs) if (tcpClientCountIncremented) { decrementTCPClientCount(remote); } - if (queuedCounterIncremented) { - g_tcpclientthreads->decrementQueuedCount(); - } } catch (...){} } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index cf96c599ad..f71ecb878a 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -54,6 +54,7 @@ #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" #include "dnsdist-secpoll.hh" +#include "dnsdist-tcp.hh" #include "dnsdist-web.hh" #include "dnsdist-xpf.hh" @@ -78,8 +79,8 @@ /* the RuleAction plan Set of Rules, if one matches, it leads to an Action - Both rules and actions could conceivably be Lua based. - On the C++ side, both could be inherited from a class Rule and a class Action, + Both rules and actions could conceivably be Lua based. + On the C++ side, both could be inherited from a class Rule and a class Action, on the Lua side we can't do that. */ using std::thread; @@ -107,7 +108,7 @@ GlobalStateHolder g_pools; size_t g_udpVectorSize{1}; /* UDP: the grand design. Per socket we listen on for incoming queries there is one thread. - Then we have a bunch of connected sockets for talking to downstream servers. + Then we have a bunch of connected sockets for talking to downstream servers. We send directly to those sockets. For the return path, per downstream server we have a thread that listens to responses. @@ -115,7 +116,7 @@ size_t g_udpVectorSize{1}; Per socket there is an array of 2^16 states, when we send out a packet downstream, we note there the original requestor and the original id. The new ID is the offset in the array. 
- When an answer comes in on a socket, we look up the offset by the id, and lob it to the + When an answer comes in on a socket, we look up the offset by the id, and lob it to the original requestor. IDs are assigned by atomic increments of the socket offset. @@ -633,6 +634,16 @@ void responderThread(std::shared_ptr dss) dh->id = ids->origID; + /* don't call processResponse on a truncated answer for DoH, we will retry over TCP */ + if (du && dh->tc) { +#ifdef HAVE_DNS_OVER_HTTPS + // DoH query + cerr<<"truncated answer for DoH"<handleUDPResponse(std::move(response), std::move(*ids)); +#endif + continue; + } + DNSResponse dr = makeDNSResponseFromIDState(*ids, response); if (dh->tc && g_truncateTC) { truncateTC(response, dr.getMaximumSize(), qnameWireLength); @@ -647,27 +658,11 @@ void responderThread(std::shared_ptr dss) if (du) { #ifdef HAVE_DNS_OVER_HTTPS // DoH query - du->response = std::move(response); - static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); - ssize_t sent = write(du->rsock, &du, sizeof(du)); - if (sent != sizeof(du)) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - ++g_stats.dohResponsePipeFull; - vinfolog("Unable to pass a DoH response to the DoH worker thread because the pipe is full"); - } - else { - vinfolog("Unable to pass a DoH response to the DoH worker thread because we couldn't write to the pipe: %s", stringerror()); - } - - /* at this point we have the only remaining pointer on this - DOHUnit object since we did set ids->du to nullptr earlier, - except if we got the response before the pointer could be - released by the frontend */ - du->release(); - } -#endif /* HAVE_DNS_OVER_HTTPS */ + du->handleUDPResponse(std::move(response), IDState()); +#endif du = nullptr; } + else { ComboAddress empty; empty.sin4.sin_family = 0; @@ -889,7 +884,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case 
DNSAction::Action::Refused: vinfolog("Query from %s refused because of dynamic block", dq.remote->toStringWithPort()); updateBlockStats(); - + dq.getHeader()->rcode = RCode::Refused; dq.getHeader()->qr = true; return true; @@ -954,7 +949,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::Truncate: if (!dq.overTCP()) { updateBlockStats(); - + vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); dq.getHeader()->tc = true; dq.getHeader()->qr = true; @@ -1210,7 +1205,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& // we need ECS parsing (parseECS) to be true so we can be sure that the initial incoming query did not have an existing // ECS option, which would make it unsuitable for the zero-scope feature. if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) { - if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == DNSQuestion::Protocol::DoH, allowExpired)) { + if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == dnsdist::Protocol::DoH, allowExpired)) { if (!prepareOutgoingResponse(holders, cs, dq, true)) { return ProcessQueryResult::Drop; @@ -1232,7 +1227,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } if (dq.packetCache && !dq.skipCache) { - if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == DNSQuestion::Protocol::DoH, allowExpired)) { + if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, !dq.overTCP() || dq.getProtocol() == dnsdist::Protocol::DoH, allowExpired)) { restoreFlags(dq.getHeader(), dq.origFlags); @@ -1334,7 +1329,7 @@ 
static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct uint16_t qtype, qclass; unsigned int qnameWireLength = 0; DNSName qname(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); - DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, dnsCryptQuery ? DNSQuestion::Protocol::DNSCryptUDP : DNSQuestion::Protocol::DoUDP, &queryRealTime); + DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, dnsCryptQuery ? dnsdist::Protocol::DNSCryptUDP : dnsdist::Protocol::DoUDP, &queryRealTime); dq.dnsCryptQuery = std::move(dnsCryptQuery); if (!proxyProtocolValues.empty()) { dq.proxyProtocolValues = make_unique>(std::move(proxyProtocolValues)); @@ -1716,7 +1711,7 @@ static void healthChecksThread() dss->dropRate.store(1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta); dss->prev.queries.store(dss->queries.load()); dss->prev.reuseds.store(dss->reuseds.load()); - + for (IDState& ids : dss->idStates) { // timeouts int64_t usageIndicator = ids.usageIndicator; if(IDState::isInUse(usageIndicator) && ids.age++ > g_udpTimeout) { @@ -1749,7 +1744,7 @@ static void healthChecksThread() fake.id = ids.origID; g_rings.insertResponse(ts, ids.origRemote, ids.qname, ids.qtype, std::numeric_limits::max(), 0, fake, dss->remote); - } + } } } @@ -1969,7 +1964,7 @@ static void setUpLocalBind(std::unique_ptr& cs) cs->ready = true; } -struct +struct { vector locals; vector remotes; @@ -2055,7 +2050,7 @@ int main(int argc, char** argv) srandom(tv.tv_sec ^ tv.tv_usec ^ getpid()); g_hashperturb=random(); } - + #endif ComboAddress clientAddress = ComboAddress(); g_cmdLine.config=SYSCONFDIR "/dnsdist.conf"; @@ -2403,7 +2398,7 @@ int main(int argc, char** argv) g_maxTCPClientThreads = 1; } - g_tcpclientthreads = std::unique_ptr(new 
TCPClientCollection(*g_maxTCPClientThreads, g_useTCPSinglePipe)); + g_tcpclientthreads = std::make_unique(*g_maxTCPClientThreads); for (auto& t : todo) { t(); @@ -2481,7 +2476,7 @@ int main(int argc, char** argv) thread stattid(maintThread); stattid.detach(); - + thread healththread(healthChecksThread); thread dynBlockMaintThread(dynBlockMaintenanceThread); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index fadd2345e4..272c9498e9 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -38,10 +38,10 @@ #include "dnsdist-cache.hh" #include "dnsdist-dynbpf.hh" #include "dnsdist-lbpolicies.hh" +#include "dnsdist-protocols.hh" #include "dnsname.hh" #include "doh.hh" #include "ednsoptions.hh" -#include "gettime.hh" #include "iputils.hh" #include "misc.hh" #include "mplexer.hh" @@ -52,7 +52,6 @@ #include "proxy-protocol.hh" #include "stat_t.hh" -void carbonDumpThread(); uint64_t uptimeOfProcess(const std::string& str); extern uint16_t g_ECSSourcePrefixV4; @@ -63,14 +62,7 @@ using QTag = std::unordered_map; struct DNSQuestion { - enum class Protocol : uint8_t { DoUDP, DoTCP, DNSCryptUDP, DNSCryptTCP, DoT, DoH }; - static const std::string& ProtocolToString(Protocol proto) - { - static const std::vector values = { "Do53 UDP", "Do53 TCP", "DNSCrypt UDP", "DNSCrypt TCP", "DNS over TLS", "DNS over HTTPS" }; - return values.at(static_cast(proto)); - } - - DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, Protocol proto, const struct timespec* queryTime_): + DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, dnsdist::Protocol proto, const struct timespec* queryTime_): data(data_), qname(name), local(lc), remote(rem), queryTime(queryTime_), tempFailureTTL(boost::none), qtype(type), qclass(class_), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? 
g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), protocol(proto), ecsOverride(g_ECSOverride) { const uint16_t* flags = getFlagsFromDNSHeader(getHeader()); origFlags = *flags; @@ -119,14 +111,14 @@ struct DNSQuestion return 4096; } - Protocol getProtocol() const + dnsdist::Protocol getProtocol() const { return protocol; } bool overTCP() const { - return !(protocol == Protocol::DoUDP || protocol == Protocol::DNSCryptUDP); + return !(protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP); } protected: @@ -162,7 +154,7 @@ public: uint16_t ecsPrefixLength; uint16_t origFlags; uint16_t cacheFlags{0}; /* DNS flags as sent to the backend */ - const Protocol protocol; + const dnsdist::Protocol protocol; uint8_t ednsRCode{0}; bool skipCache{false}; bool ecsOverride; @@ -177,7 +169,7 @@ public: struct DNSResponse : DNSQuestion { - DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, DNSQuestion::Protocol proto, const struct timespec* queryTime_): + DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, dnsdist::Protocol proto, const struct timespec* queryTime_): DNSQuestion(name, type, class_, lc, rem, data_, proto, queryTime_) { } DNSResponse(const DNSResponse&) = delete; DNSResponse& operator=(const DNSResponse&) = delete; @@ -420,44 +412,7 @@ struct DNSDistStats extern struct DNSDistStats g_stats; void doLatencyStats(double udiff); - -struct StopWatch -{ - StopWatch(bool realTime=false): d_needRealTime(realTime) - { - } - struct timespec d_start{0,0}; - bool d_needRealTime{false}; - - void start() { - if(gettime(&d_start, d_needRealTime) < 0) - unixDie("Getting timestamp"); - - } - - void set(const struct timespec& from) { - d_start = from; - } - - double udiff() const { - struct timespec now; - if(gettime(&now, d_needRealTime) < 0) - unixDie("Getting timestamp"); - - return 
1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0; - } - - double udiffAndSet() { - struct timespec now; - if(gettime(&now, d_needRealTime) < 0) - unixDie("Getting timestamp"); - - auto ret= 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0; - d_start = now; - return ret; - } - -}; +#include "dnsdist-idstate.hh" class BasicQPSLimiter { @@ -568,189 +523,6 @@ private: bool d_passthrough{true}; }; -struct ClientState; - -/* g++ defines __SANITIZE_THREAD__ - clang++ supports the nice __has_feature(thread_sanitizer), - let's merge them */ -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) -#define __SANITIZE_THREAD__ 1 -#endif -#endif - -struct IDState -{ - IDState(): sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;} - IDState(const IDState& orig) = delete; - IDState(IDState&& rhs): subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), origFD(rhs.origFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope) - { - if (rhs.isInUse()) { - throw std::runtime_error("Trying to move an in-use IDState"); - } - - uniqueId = std::move(rhs.uniqueId); -#ifdef __SANITIZE_THREAD__ - age.store(rhs.age.load()); -#else - age = rhs.age; -#endif - } - - IDState& operator=(IDState&& rhs) - { - if (isInUse()) { - throw std::runtime_error("Trying to overwrite an 
in-use IDState"); - } - - if (rhs.isInUse()) { - throw std::runtime_error("Trying to move an in-use IDState"); - } - - subnet = std::move(rhs.subnet); - origRemote = rhs.origRemote; - origDest = rhs.origDest; - hopRemote = rhs.hopRemote; - hopLocal = rhs.hopLocal; - qname = std::move(rhs.qname); - sentTime = rhs.sentTime; - dnsCryptQuery = std::move(rhs.dnsCryptQuery); - packetCache = std::move(rhs.packetCache); - qTag = std::move(rhs.qTag); - tempFailureTTL = std::move(rhs.tempFailureTTL); - cs = rhs.cs; - du = std::move(rhs.du); - cacheKey = rhs.cacheKey; - cacheKeyNoECS = rhs.cacheKeyNoECS; - origFD = rhs.origFD; - delayMsec = rhs.delayMsec; -#ifdef __SANITIZE_THREAD__ - age.store(rhs.age.load()); -#else - age = rhs.age; -#endif - qtype = rhs.qtype; - qclass = rhs.qclass; - origID = rhs.origID; - origFlags = rhs.origFlags; - cacheFlags = rhs.cacheFlags; - protocol = rhs.protocol; - uniqueId = std::move(rhs.uniqueId); - ednsAdded = rhs.ednsAdded; - ecsAdded = rhs.ecsAdded; - skipCache = rhs.skipCache; - destHarvested = rhs.destHarvested; - dnssecOK = rhs.dnssecOK; - useZeroScope = rhs.useZeroScope; - - return *this; - } - - static const int64_t unusedIndicator = -1; - - static bool isInUse(int64_t usageIndicator) - { - return usageIndicator != unusedIndicator; - } - - bool isInUse() const - { - return usageIndicator != unusedIndicator; - } - - /* return true if the value has been successfully replaced meaning that - no-one updated the usage indicator in the meantime */ - bool tryMarkUnused(int64_t expectedUsageIndicator) - { - return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator); - } - - /* mark as used no matter what, return true if the state was in use before */ - bool markAsUsed() - { - auto currentGeneration = generation++; - return markAsUsed(currentGeneration); - } - - /* mark as used no matter what, return true if the state was in use before */ - bool markAsUsed(int64_t currentGeneration) - { - int64_t oldUsage = 
usageIndicator.exchange(currentGeneration); - return oldUsage != unusedIndicator; - } - - /* We use this value to detect whether this state is in use. - For performance reasons we don't want to use a lock here, but that means - we need to be very careful when modifying this value. Modifications happen - from: - - one of the UDP or DoH 'client' threads receiving a query, selecting a backend - then picking one of the states associated to this backend (via the idOffset). - Most of the time this state should not be in use and usageIndicator is -1, but we - might not yet have received a response for the query previously associated to this - state, meaning that we will 'reuse' this state and erase the existing state. - If we ever receive a response for this state, it will be discarded. This is - mostly fine for UDP except that we still need to be careful in order to miss - the 'outstanding' counters, which should only be increased when we are picking - an empty state, and not when reusing ; - For DoH, though, we have dynamically allocated a DOHUnit object that needs to - be freed, as well as internal objects internals to libh2o. - - one of the UDP receiver threads receiving a response from a backend, picking - the corresponding state and sending the response to the client ; - - the 'healthcheck' thread scanning the states to actively discover timeouts, - mostly to keep some counters like the 'outstanding' one sane. 
- We previously based that logic on the origFD (FD on which the query was received, - and therefore from where the response should be sent) but this suffered from an - ABA problem since it was quite likely that a UDP 'client thread' would reset it to the - same value since we only have so much incoming sockets: - - 1/ 'client' thread gets a query and set origFD to its FD, say 5 ; - - 2/ 'receiver' thread gets a response, read the value of origFD to 5, check that the qname, - qtype and qclass match - - 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ; - - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still - 5, except it's not the same 5 anymore and it overrides a fresh state. - We now use a 32-bit unsigned counter instead, which is incremented every time the state is set, - wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1 - when the state is unused and the value of our counter otherwise. 
- */ - boost::optional subnet{boost::none}; // 40 - ComboAddress origRemote; // 28 - ComboAddress origDest; // 28 - ComboAddress hopRemote; - ComboAddress hopLocal; - DNSName qname; // 24 - StopWatch sentTime; // 16 - std::shared_ptr dnsCryptQuery{nullptr}; // 16 - std::shared_ptr packetCache{nullptr}; // 16 - std::shared_ptr qTag{nullptr}; // 16 - boost::optional tempFailureTTL; // 8 - const ClientState* cs{nullptr}; // 8 - DOHUnit* du{nullptr}; // 8 - std::atomic usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8 - std::atomic generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4 - uint32_t cacheKey{0}; // 4 - uint32_t cacheKeyNoECS{0}; // 4 - int origFD{-1}; // 4 - int delayMsec{0}; -#ifdef __SANITIZE_THREAD__ - std::atomic age{0}; -#else - uint16_t age{0}; // 2 -#endif - uint16_t qtype{0}; // 2 - uint16_t qclass{0}; // 2 - uint16_t origID{0}; // 2 - uint16_t origFlags{0}; // 2 - uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2 - DNSQuestion::Protocol protocol; // 1 - boost::optional uniqueId{boost::none}; // 17 (placed here to reduce the space lost to padding) - bool ednsAdded{false}; - bool ecsAdded{false}; - bool skipCache{false}; - bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr - bool dnssecOK{false}; - bool useZeroScope{false}; -}; - typedef std::unordered_map QueryCountRecords; typedef std::function(const DNSQuestion* dq)> QueryCountFilter; struct QueryCount { @@ -781,15 +553,15 @@ struct ClientState std::string interface; stat_t queries{0}; mutable stat_t responses{0}; - stat_t tcpDiedReadingQuery{0}; - stat_t tcpDiedSendingResponse{0}; - stat_t tcpGaveUp{0}; - stat_t tcpClientTimeouts{0}; - stat_t tcpDownstreamTimeouts{0}; + mutable stat_t tcpDiedReadingQuery{0}; + mutable stat_t tcpDiedSendingResponse{0}; + mutable stat_t tcpGaveUp{0}; + mutable stat_t tcpClientTimeouts{0}; + mutable stat_t 
tcpDownstreamTimeouts{0}; /* current number of connections to this frontend */ - stat_t tcpCurrentConnections{0}; + mutable stat_t tcpCurrentConnections{0}; /* maximum number of concurrent connections to this frontend reached */ - stat_t tcpMaxConcurrentConnections{0}; + mutable stat_t tcpMaxConcurrentConnections{0}; stat_t tlsNewSessions{0}; // A new TLS session has been negotiated, no resumption stat_t tlsResumptions{0}; // A TLS session has been resumed, either via session id or via a TLS ticket stat_t tlsUnknownTicketKey{0}; // A TLS ticket has been presented but we don't have the associated key (might have expired) @@ -875,49 +647,6 @@ struct ClientState } }; -class TCPClientCollection { - std::vector d_tcpclientthreads; - stat_t d_numthreads{0}; - stat_t d_pos{0}; - stat_t d_queued{0}; - const uint64_t d_maxthreads{0}; - std::mutex d_mutex; - int d_singlePipe[2]; - const bool d_useSinglePipe; -public: - - TCPClientCollection(size_t maxThreads, bool useSinglePipe=false); - int getThread() - { - if (d_numthreads == 0) { - throw std::runtime_error("No TCP worker thread yet"); - } - - uint64_t pos = d_pos++; - ++d_queued; - return d_tcpclientthreads.at(pos % d_numthreads); - } - bool hasReachedMaxThreads() const - { - return d_numthreads >= d_maxthreads; - } - uint64_t getThreadsCount() const - { - return d_numthreads; - } - uint64_t getQueuedCount() const - { - return d_queued; - } - void decrementQueuedCount() - { - --d_queued; - } - void addTCPClientThread(); -}; - -extern std::unique_ptr g_tcpclientthreads; - struct DownstreamState { typedef std::function(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t; @@ -1269,3 +998,5 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname); int pickBackendSocketForSending(std::shared_ptr& state); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); + +void carbonDumpThread(); diff --git 
a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 42b16a3de5..83a579f983 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -140,7 +140,7 @@ dnsdist_SOURCES = \ dnsdist-dynbpf.cc dnsdist-dynbpf.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ dnsdist-healthchecks.cc dnsdist-healthchecks.hh \ - dnsdist-idstate.cc \ + dnsdist-idstate.cc dnsdist-idstate.hh \ dnsdist-kvs.hh dnsdist-kvs.cc \ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \ dnsdist-lua-actions.cc \ @@ -160,6 +160,7 @@ dnsdist_SOURCES = \ dnsdist-lua.cc dnsdist-lua.hh \ dnsdist-prometheus.hh \ dnsdist-protobuf.cc dnsdist-protobuf.hh \ + dnsdist-protocols.cc dnsdist-protocols.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ dnsdist-rings.cc dnsdist-rings.hh \ dnsdist-rules.cc dnsdist-rules.hh \ @@ -168,7 +169,7 @@ dnsdist_SOURCES = \ dnsdist-systemd.cc dnsdist-systemd.hh \ dnsdist-tcp-downstream.cc dnsdist-tcp-downstream.hh \ dnsdist-tcp-upstream.hh \ - dnsdist-tcp.cc \ + dnsdist-tcp.cc dnsdist-tcp.hh \ dnsdist-web.cc dnsdist-web.hh \ dnsdist-xpf.cc dnsdist-xpf.hh \ dnsdist.cc dnsdist.hh \ @@ -229,7 +230,7 @@ testrunner_SOURCES = \ dnsdist-dynblocks.cc dnsdist-dynblocks.hh \ dnsdist-dynbpf.cc dnsdist-dynbpf.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ - dnsdist-idstate.cc \ + dnsdist-idstate.cc dnsdist-idstate.hh \ dnsdist-kvs.cc dnsdist-kvs.hh \ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \ dnsdist-lua-bindings-dnsquestion.cc \ @@ -238,11 +239,12 @@ testrunner_SOURCES = \ dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \ dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \ dnsdist-lua-vars.cc \ + dnsdist-protocols.cc dnsdist-protocols.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ dnsdist-rings.cc dnsdist-rings.hh \ dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-tcp-downstream.cc \ - dnsdist-tcp.cc \ + dnsdist-tcp.cc dnsdist-tcp.hh \ dnsdist-xpf.cc dnsdist-xpf.hh \ dnsdist.hh \ dnslabeltext.cc \ diff --git a/pdns/dnsdistdist/dnsdist-idstate.hh 
b/pdns/dnsdistdist/dnsdist-idstate.hh new file mode 120000 index 0000000000..44f6de4345 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-idstate.hh @@ -0,0 +1 @@ +../dnsdist-idstate.hh \ No newline at end of file diff --git a/pdns/dnsdistdist/dnsdist-protocols.cc b/pdns/dnsdistdist/dnsdist-protocols.cc new file mode 100644 index 0000000000..233bf4b876 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-protocols.cc @@ -0,0 +1,31 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ */ +#include "dnsdist-protocols.hh" + +namespace dnsdist { + const std::string& ProtocolToString(Protocol proto) + { + static const std::vector values = { "Do53 UDP", "Do53 TCP", "DNSCrypt UDP", "DNSCrypt TCP", "DNS over TLS", "DNS over HTTPS" }; + return values.at(static_cast(proto)); + } +} + diff --git a/pdns/dnsdistdist/dnsdist-protocols.hh b/pdns/dnsdistdist/dnsdist-protocols.hh new file mode 120000 index 0000000000..cb9d2fd79c --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-protocols.hh @@ -0,0 +1 @@ +../dnsdist-protocols.hh \ No newline at end of file diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index 7a8b28cd6e..9623866699 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -4,37 +4,14 @@ #include "dnsparser.hh" -const uint16_t TCPConnectionToBackend::s_xfrID = 0; - -void TCPConnectionToBackend::assignToClientConnection(std::shared_ptr& clientConn, bool isXFR) -{ - if (d_usedForXFR == true) { - throw std::runtime_error("Trying to send a query over a backend connection used for XFR"); - } - - if (isXFR) { - d_usedForXFR = true; - } - - if (!d_clientConn) { - d_clientConn = clientConn; - d_ioState = make_unique(clientConn->getIOMPlexer(), d_handler->getDescriptor()); - } - else if (d_clientConn != clientConn) { - throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries"); - } -} - void TCPConnectionToBackend::release() { - if (!d_usedForXFR) { - d_ds->outstanding -= d_pendingResponses.size(); - } + d_ds->outstanding -= d_pendingResponses.size(); d_pendingResponses.clear(); d_pendingQueries.clear(); - d_clientConn.reset(); + d_sender.reset(); if (d_ioState) { d_ioState.reset(); } @@ -72,9 +49,7 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptrd_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery); conn->d_currentQuery.d_buffer.clear(); - 
if (!conn->d_usedForXFR) { - ++conn->d_ds->outstanding; - } + ++conn->d_ds->outstanding; return state; } @@ -185,20 +160,34 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c DEBUGLOG("connection died, number of failures is "<d_downstreamFailures<<", retries is "<d_ds->retries); - if ((!conn->d_usedForXFR || conn->d_queries == 0) && conn->d_downstreamFailures < conn->d_ds->retries) { + if (conn->d_downstreamFailures < conn->d_ds->retries) { conn->d_ioState.reset(); ioGuard.release(); try { if (conn->reconnect()) { - conn->d_ioState = make_unique(conn->d_clientConn->getIOMPlexer(), conn->d_handler->getDescriptor()); + conn->d_ioState = make_unique(conn->d_mplexer, conn->d_handler->getDescriptor()); /* we need to resend the queries that were in flight, if any */ for (auto& pending : conn->d_pendingResponses) { - conn->d_pendingQueries.push_back(std::move(pending.second)); - if (!conn->d_usedForXFR) { - --conn->d_ds->outstanding; + --conn->d_ds->outstanding; + + if (pending.second.isXFR() && pending.second.d_xfrStarted) { + /* this one can't be restarted, sorry */ + DEBUGLOG("A XFR for which a response has already been sent cannot be restarted"); + try { + conn->d_sender->notifyIOError(std::move(pending.second.d_idstate), now); + } + catch (const std::exception& e) { + vinfolog("Got an exception while notifying: %s", e.what()); + } + catch (...) 
{ + vinfolog("Got exception while notifying"); + } + } + else { + conn->d_pendingQueries.push_back(std::move(pending.second)); } } conn->d_pendingResponses.clear(); @@ -278,10 +267,14 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t handleIO(conn, now); } -void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr& sharedSelf) +void TCPConnectionToBackend::queueQuery(std::shared_ptr& sender, TCPQuery&& query) { - if (d_ioState == nullptr) { - throw std::runtime_error("Trying to queue a query to a TCP connection that has no incoming client connection assigned"); + if (!d_sender) { + d_sender = sender; + d_ioState = make_unique(d_mplexer, d_handler->getDescriptor()); + } + else if (d_sender != sender) { + throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries"); } // if we are not already sending a query or in the middle of reading a response (so idle or doingHandshake), @@ -299,7 +292,8 @@ void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptractive()) { + auto& sender = d_sender; + if (!sender->active()) { // a client timeout occurred, or something like that */ - d_clientConn.reset(); + d_sender.reset(); return; } if (reason == FailureReason::timeout) { - ++clientConn->d_ci.cs->tcpDownstreamTimeouts; + ++sender->getClientState().tcpDownstreamTimeouts; } else if (reason == FailureReason::gaveUp) { - ++clientConn->d_ci.cs->tcpGaveUp; + ++sender->getClientState().tcpGaveUp; } try { if (d_state == State::sendingQueryToBackend) { - clientConn->notifyIOError(clientConn, std::move(d_currentQuery.d_idstate), now); + sender->notifyIOError(std::move(d_currentQuery.d_idstate), now); } for (auto& query : d_pendingQueries) { - clientConn->notifyIOError(clientConn, std::move(query.d_idstate), now); + sender->notifyIOError(std::move(query.d_idstate), now); } for (auto& response : d_pendingResponses) { - clientConn->notifyIOError(clientConn, 
std::move(response.second.d_idstate), now); + sender->notifyIOError(std::move(response.second.d_idstate), now); } } catch (const std::exception& e) { @@ -467,8 +461,8 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptractive()) { + auto& sender = d_sender; + if (!sender || !sender->active()) { // a client timeout occurred, or something like that */ d_connectionDied = true; @@ -494,11 +488,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrd_usedForXFR) { - --conn->d_ds->outstanding; - } - - if (d_usedForXFR) { + if (it->second.isXFR()) { DEBUGLOG("XFR!"); bool done = false; TCPResponse response; @@ -509,22 +499,22 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrsecond.d_idstate.qname; DEBUGLOG("passing XFRresponse to client connection for "<second.d_xfrStarted = true; + done = isXFRFinished(response, it->second); if (done) { d_pendingResponses.erase(it); + --conn->d_ds->outstanding; /* marking as idle for now, so we can accept new queries if our queues are empty */ if (d_pendingQueries.empty() && d_pendingResponses.empty()) { d_state = State::idle; } - clientConn->d_isXFR = false; - conn->d_usedForXFR = false; } - clientConn->handleXFRResponse(clientConn, now, std::move(response)); + sender->handleXFRResponse(now, std::move(response)); if (done) { d_state = State::idle; - d_clientConn.reset(); + d_sender.reset(); return IOState::Done; } @@ -534,6 +524,9 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrd_ds->outstanding; + } auto ids = std::move(it->second.d_idstate); d_pendingResponses.erase(it); @@ -543,7 +536,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrhandleResponse(clientConn, now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn)); + sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn)); if (!d_pendingQueries.empty()) { DEBUGLOG("still have some queries to send"); @@ -563,7 +556,7 @@ IOState 
TCPConnectionToBackend::handleResponse(std::shared_ptr& clientConn) +bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, TCPQuery& query) { bool done = false; try { @@ -626,20 +619,20 @@ bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, const sh auto raw = unknownContent->getRawContent(); auto serial = getSerialFromRawSOAContent(raw); - ++clientConn->d_xfrSerialCount; - if (clientConn->d_xfrMasterSerial == 0) { + ++query.d_xfrSerialCount; + if (query.d_xfrMasterSerial == 0) { // store the first SOA in our client's connection metadata - ++clientConn->d_xfrMasterSerialCount; - clientConn->d_xfrMasterSerial = serial; + ++query.d_xfrMasterSerialCount; + query.d_xfrMasterSerial = serial; } - else if (clientConn->d_xfrMasterSerial == serial) { - ++clientConn->d_xfrMasterSerialCount; + else if (query.d_xfrMasterSerial == serial) { + ++query.d_xfrMasterSerialCount; // figure out if it's end when receiving master's SOA again - if (clientConn->d_xfrSerialCount == 2) { + if (query.d_xfrSerialCount == 2) { // if there are only two SOA records marks a finished AXFR done = true; } - if (clientConn->d_xfrMasterSerialCount == 3) { + if (query.d_xfrMasterSerialCount == 3) { // receiving master's SOA 3 times marks a finished IXFR done = true; } diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh index f9d26e5180..228bee4b3b 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh @@ -5,49 +5,12 @@ #include "sstuff.hh" #include "tcpiohandler-mplexer.hh" #include "dnsdist.hh" +#include "dnsdist-tcp.hh" -struct TCPQuery -{ - TCPQuery() - { - } - - TCPQuery(PacketBuffer&& buffer, IDState&& state): d_idstate(std::move(state)), d_buffer(std::move(buffer)) - { - } - - IDState d_idstate; - PacketBuffer d_buffer; - std::string d_proxyProtocolPayload; - bool d_proxyProtocolPayloadAdded{false}; -}; - -class TCPConnectionToBackend; - -struct TCPResponse : 
public TCPQuery -{ - TCPResponse() - { - /* let's make Coverity happy */ - memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); - } - - TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr conn): TCPQuery(std::move(buffer), std::move(state)), d_connection(conn) - { - memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); - } - - std::shared_ptr d_connection{nullptr}; - dnsheader d_cleartextDH; - bool d_selfGenerated{false}; -}; - -class IncomingTCPConnectionState; - -class TCPConnectionToBackend +class TCPConnectionToBackend : public std::enable_shared_from_this { public: - TCPConnectionToBackend(std::shared_ptr& ds, const struct timeval& now): d_responseBuffer(s_maxPacketCacheEntrySize), d_ds(ds), d_connectionStartTime(now), d_lastDataReceivedTime(now), d_enableFastOpen(ds->tcpFastOpen) + TCPConnectionToBackend(std::shared_ptr& ds, std::unique_ptr& mplexer, const struct timeval& now): d_responseBuffer(s_maxPacketCacheEntrySize), d_mplexer(mplexer), d_ds(ds), d_connectionStartTime(now), d_lastDataReceivedTime(now), d_enableFastOpen(ds->tcpFastOpen) { reconnect(); } @@ -64,8 +27,6 @@ public: } } - void assignToClientConnection(std::shared_ptr& clientConn, bool isXFR); - int getHandle() const { if (!d_handler) { @@ -118,10 +79,8 @@ public: /* whether we can accept new queries FOR THE SAME CLIENT */ bool canAcceptNewQueries() const { - if (d_usedForXFR || d_connectionDied) { + if (d_connectionDied) { return false; - /* Don't reuse the TCP connection after an {A,I}XFR */ - /* but don't reset it either, we will need to read more messages */ } if ((d_pendingQueries.size() + d_pendingResponses.size()) >= d_ds->d_maxInFlightQueriesPerConn) { @@ -139,7 +98,7 @@ public: /* whether a connection can be reused for a different client */ bool canBeReused() const { - if (d_usedForXFR || d_connectionDied) { + if (d_connectionDied) { return false; } /* we can't reuse a connection where a proxy protocol payload has been sent, @@ -163,7 +122,7 @@ public: return ds == d_ds; } - 
void queueQuery(TCPQuery&& query, std::shared_ptr& sharedSelf); + void queueQuery(std::shared_ptr& sender, TCPQuery&& query); void handleTimeout(const struct timeval& now, bool write); void release(); @@ -177,7 +136,7 @@ public: std::string toString() const { ostringstream o; - o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<& conn); static IOState sendQuery(std::shared_ptr& conn, const struct timeval& now); - static bool isXFRFinished(const TCPResponse& response, const shared_ptr& clientConn); + static bool isXFRFinished(const TCPResponse& response, TCPQuery& query); IOState handleResponse(std::shared_ptr& conn, const struct timeval& now); uint16_t getQueryIdFromResponse(); @@ -247,16 +206,15 @@ private: return res; } - static const uint16_t s_xfrID; - PacketBuffer d_responseBuffer; std::deque d_pendingQueries; std::unordered_map d_pendingResponses; + std::unique_ptr& d_mplexer; std::unique_ptr> d_proxyProtocolValuesSent{nullptr}; std::unique_ptr d_handler{nullptr}; std::unique_ptr d_ioState{nullptr}; std::shared_ptr d_ds{nullptr}; - std::shared_ptr d_clientConn; + std::shared_ptr d_sender{nullptr}; TCPQuery d_currentQuery; struct timeval d_connectionStartTime; struct timeval d_lastDataReceivedTime; @@ -268,6 +226,5 @@ private: bool d_fresh{true}; bool d_enableFastOpen{false}; bool d_connectionDied{false}; - bool d_usedForXFR{false}; bool d_proxyProtocolPayloadSent{false}; }; diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 
5eed088486..7db91dca5e 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -1,6 +1,7 @@ #pragma once #include "dolog.hh" +#include "dnsdist-tcp.hh" class TCPClientThreadData { @@ -14,48 +15,7 @@ public: std::unique_ptr mplexer{nullptr}; }; -struct ConnectionInfo -{ - ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1) - { - } - ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd) - { - rhs.cs = nullptr; - rhs.fd = -1; - } - - ConnectionInfo(const ConnectionInfo& rhs) = delete; - ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete; - - ConnectionInfo& operator=(ConnectionInfo&& rhs) - { - remote = rhs.remote; - cs = rhs.cs; - rhs.cs = nullptr; - fd = rhs.fd; - rhs.fd = -1; - return *this; - } - - ~ConnectionInfo() - { - if (fd != -1) { - close(fd); - fd = -1; - } - - if (cs) { - --cs->tcpCurrentConnections; - } - } - - ComboAddress remote; - ClientState* cs{nullptr}; - int fd{-1}; -}; - -class IncomingTCPConnectionState +class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this { public: IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? 
d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_ioState(make_unique(threadData.mplexer, d_ci.fd)), d_connectionStartTime(now) @@ -145,34 +105,35 @@ public: std::shared_ptr getDownstreamConnection(std::shared_ptr& ds, const std::unique_ptr>& tlvs, const struct timeval& now); void registerActiveDownstreamConnection(std::shared_ptr& conn); - std::unique_ptr& getIOMPlexer() const - { - return d_threadData.mplexer; - } - static size_t clearAllDownstreamConnections(); static void handleIO(std::shared_ptr& conn, const struct timeval& now); static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); - static void notifyIOError(std::shared_ptr& state, IDState&& query, const struct timeval& now); + static IOState sendResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); static void queueResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); +static void handleTimeout(std::shared_ptr& state, bool write); /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */ - static void handleResponse(std::shared_ptr state, const struct timeval& now, TCPResponse&& response); - static void handleXFRResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); - static void handleTimeout(std::shared_ptr& state, bool write); + void handleResponse(const struct timeval& now, TCPResponse&& response) override; + void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override; + void notifyIOError(IDState&& query, const struct timeval& now) override; void terminateClientConnection(); void queueQuery(TCPQuery&& query); bool canAcceptNewQueries(const struct timeval& now); - bool active() const + bool active() const override { return d_ioState != nullptr; } + const ClientState& getClientState() override + { + return *d_ci.cs; + } + std::string toString() const { ostringstream o; @@ 
-203,9 +164,6 @@ public: size_t d_proxyProtocolNeed{0}; size_t d_queriesCount{0}; size_t d_currentQueriesCount{0}; - uint32_t d_xfrMasterSerial{0}; - uint32_t d_xfrSerialCount{0}; - uint8_t d_xfrMasterSerialCount{0}; uint16_t d_querySize{0}; State d_state{State::doingHandshake}; bool d_isXFR{false}; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh new file mode 100644 index 0000000000..b932e5d5a4 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -0,0 +1,296 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ */ +#pragma once + +struct ConnectionInfo +{ + ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1) + { + } + ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd) + { + rhs.cs = nullptr; + rhs.fd = -1; + } + + ConnectionInfo(const ConnectionInfo& rhs) = delete; + ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete; + + ConnectionInfo& operator=(ConnectionInfo&& rhs) + { + remote = rhs.remote; + cs = rhs.cs; + rhs.cs = nullptr; + fd = rhs.fd; + rhs.fd = -1; + return *this; + } + + ~ConnectionInfo() + { + if (fd != -1) { + close(fd); + fd = -1; + } + + if (cs) { + --cs->tcpCurrentConnections; + } + } + + ComboAddress remote; + ClientState* cs{nullptr}; + int fd{-1}; +}; + +struct InternalQuery +{ + InternalQuery() + { + } + + InternalQuery(PacketBuffer&& buffer, IDState&& state): d_idstate(std::move(state)), d_buffer(std::move(buffer)) + { + } + + InternalQuery(InternalQuery&& rhs) : + d_idstate(std::move(rhs.d_idstate)), d_buffer(std::move(rhs.d_buffer)), d_proxyProtocolPayload(std::move(rhs.d_proxyProtocolPayload)), d_xfrMasterSerial(rhs.d_xfrMasterSerial), d_xfrSerialCount(rhs.d_xfrSerialCount), d_xfrMasterSerialCount(rhs.d_xfrMasterSerialCount), d_proxyProtocolPayloadAdded(rhs.d_proxyProtocolPayloadAdded) + { + } + InternalQuery& operator=(InternalQuery&& rhs) + { + d_idstate = std::move(rhs.d_idstate); + d_buffer = std::move(rhs.d_buffer); + d_proxyProtocolPayload = std::move(rhs.d_proxyProtocolPayload); + d_xfrMasterSerial = rhs.d_xfrMasterSerial; + d_xfrSerialCount = rhs.d_xfrSerialCount; + d_xfrMasterSerialCount = rhs.d_xfrMasterSerialCount; + d_proxyProtocolPayloadAdded = rhs.d_proxyProtocolPayloadAdded; + return *this; + } + + InternalQuery(const InternalQuery& rhs) = delete; + InternalQuery& operator=(const InternalQuery& rhs) = delete; + + bool isXFR() const + { + return d_idstate.qtype == QType::AXFR || d_idstate.qtype == QType::IXFR; + } + + IDState d_idstate; + PacketBuffer d_buffer; + std::string 
d_proxyProtocolPayload; + uint32_t d_xfrMasterSerial{0}; + uint32_t d_xfrSerialCount{0}; + uint8_t d_xfrMasterSerialCount{0}; + bool d_xfrStarted{false}; + bool d_proxyProtocolPayloadAdded{false}; +}; + +using TCPQuery = InternalQuery; + +class TCPConnectionToBackend; + +struct TCPResponse : public TCPQuery +{ + TCPResponse() + { + /* let's make Coverity happy */ + memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); + } + + TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr conn): TCPQuery(std::move(buffer), std::move(state)), d_connection(conn) + { + memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); + } + + std::shared_ptr d_connection{nullptr}; + dnsheader d_cleartextDH; + bool d_selfGenerated{false}; +}; + +class TCPQuerySender +{ +public: + virtual ~TCPQuerySender() + { + } + + virtual bool active() const = 0; + virtual const ClientState& getClientState() = 0; + virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0; + virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0; + virtual void notifyIOError(IDState&& query, const struct timeval& now) = 0; +}; + +struct CrossProtocolQuery +{ + CrossProtocolQuery() + { + } + + CrossProtocolQuery(CrossProtocolQuery&& rhs) = delete; + virtual ~CrossProtocolQuery() + { + } + + virtual std::shared_ptr getTCPQuerySender() = 0; + + InternalQuery query; + std::shared_ptr downstream{nullptr}; +}; + +class TCPClientCollection { +public: + TCPClientCollection(size_t maxThreads); + + int getThread() + { + if (d_numthreads == 0) { + throw std::runtime_error("No TCP worker thread yet"); + } + + uint64_t pos = d_pos++; + ++d_queued; + return d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe; + } + + bool passConnectionToThread(std::unique_ptr&& conn) + { + if (d_numthreads == 0) { + throw std::runtime_error("No TCP worker thread yet"); + } + + uint64_t pos = d_pos++; + auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe; + 
auto tmp = conn.release(); + + if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { + delete tmp; + tmp = nullptr; + return false; + } + ++d_queued; + return true; + } + + bool passCrossProtocolQueryToThread(std::unique_ptr&& cpq) + { + if (d_numthreads == 0) { + throw std::runtime_error("No TCP worker thread yet"); + } + + uint64_t pos = d_pos++; + auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueryPipe; + auto tmp = cpq.release(); + + if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { + delete tmp; + tmp = nullptr; + return false; + } + + return true; + } + + bool hasReachedMaxThreads() const + { + return d_numthreads >= d_maxthreads; + } + + uint64_t getThreadsCount() const + { + return d_numthreads; + } + + uint64_t getQueuedCount() const + { + return d_queued; + } + + void decrementQueuedCount() + { + --d_queued; + } + + void addTCPClientThread(); + +private: + struct TCPWorkerThread + { + TCPWorkerThread() + { + } + + TCPWorkerThread(int newConnPipe, int crossProtocolPipe): d_newConnectionPipe(newConnPipe), d_crossProtocolQueryPipe(crossProtocolPipe) + { + } + + TCPWorkerThread(TCPWorkerThread&& rhs): d_newConnectionPipe(rhs.d_newConnectionPipe), d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe) + { + rhs.d_newConnectionPipe = -1; + rhs.d_crossProtocolQueryPipe = -1; + } + + TCPWorkerThread& operator=(TCPWorkerThread&& rhs) + { + if (d_newConnectionPipe != -1) { + close(d_newConnectionPipe); + } + if (d_crossProtocolQueryPipe != -1) { + close(d_crossProtocolQueryPipe); + } + + d_newConnectionPipe = rhs.d_newConnectionPipe; + d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe; + rhs.d_newConnectionPipe = -1; + rhs.d_crossProtocolQueryPipe = -1; + + return *this; + } + + TCPWorkerThread(const TCPWorkerThread& rhs) = delete; + TCPWorkerThread& operator=(const TCPWorkerThread&) = delete; + + ~TCPWorkerThread() + { + if (d_newConnectionPipe != -1) { + close(d_newConnectionPipe); + } + if (d_crossProtocolQueryPipe != -1) { + 
close(d_crossProtocolQueryPipe); + } + } + + int d_newConnectionPipe{-1}; + int d_crossProtocolQueryPipe{-1}; + }; + + std::mutex d_mutex; + std::vector d_tcpclientthreads; + stat_t d_numthreads{0}; + stat_t d_pos{0}; + stat_t d_queued{0}; + const uint64_t d_maxthreads{0}; +}; + +extern std::unique_ptr g_tcpclientthreads; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index e89859c3e8..296c09e6db 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -20,6 +20,7 @@ #include "dnsname.hh" #undef CERT #include "dnsdist.hh" +#include "dnsdist-tcp.hh" #include "misc.hh" #include "dns.hh" #include "dolog.hh" @@ -177,6 +178,11 @@ struct DOHServerConfig dohquerypair[0] = fd[1]; dohquerypair[1] = fd[0]; + setNonBlocking(dohquerypair[0]); + if (internalPipeBufferSize > 0) { + setPipeBufferSize(dohquerypair[0], internalPipeBufferSize); + } + if (pipe(fd) < 0) { close(dohquerypair[0]); close(dohquerypair[1]); @@ -186,11 +192,6 @@ struct DOHServerConfig dohresponsepair[0] = fd[1]; dohresponsepair[1] = fd[0]; - setNonBlocking(dohquerypair[0]); - if (internalPipeBufferSize > 0) { - setPipeBufferSize(dohquerypair[0], internalPipeBufferSize); - } - setNonBlocking(dohresponsepair[0]); if (internalPipeBufferSize > 0) { setPipeBufferSize(dohresponsepair[0], internalPipeBufferSize); @@ -198,6 +199,14 @@ struct DOHServerConfig setNonBlocking(dohresponsepair[1]); + if (pipe(fd) < 0) { + close(dohquerypair[0]); + close(dohquerypair[1]); + close(dohresponsepair[0]); + close(dohresponsepair[1]); + unixDie("Creating a pipe for DNS over HTTPS"); + } + h2o_config_init(&h2o_config); h2o_config.http2.idle_timeout = idleTimeout * 1000; } @@ -465,13 +474,12 @@ static int processDOHQuery(DOHUnit* du) uint16_t qtype, qclass; unsigned int qnameWireLength = 0; DNSName qname(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); - DNSQuestion dq(&qname, qtype, qclass, &du->dest, &du->remote, du->query, 
DNSQuestion::Protocol::DoH, &queryRealTime); + DNSQuestion dq(&qname, qtype, qclass, &du->dest, &du->remote, du->query, dnsdist::Protocol::DoH, &queryRealTime); dq.ednsAdded = du->ednsAdded; dq.du = du; dq.sni = std::move(du->sni); - std::shared_ptr ss{nullptr}; - auto result = processQuery(dq, cs, holders, ss); + auto result = processQuery(dq, cs, holders, du->downstream); if (result == ProcessQueryResult::Drop) { du->status_code = 403; @@ -493,14 +501,14 @@ static int processDOHQuery(DOHUnit* du) return -1; } - if (ss == nullptr) { + if (du->downstream == nullptr) { du->status_code = 502; return -1; } ComboAddress dest = du->dest; - unsigned int idOffset = (ss->idOffset++) % ss->idStates.size(); - IDState* ids = &ss->idStates[idOffset]; + unsigned int idOffset = (du->downstream->idOffset++) % du->downstream->idStates.size(); + IDState* ids = &du->downstream->idStates[idOffset]; ids->age = 0; DOHUnit* oldDU = nullptr; if (ids->isInUse()) { @@ -516,13 +524,13 @@ static int processDOHQuery(DOHUnit* du) /* the state was not in use. we reset 'oldDU' because it might have still been in use when we read it. */ oldDU = nullptr; - ++ss->outstanding; + ++du->downstream->outstanding; } else { ids->du = nullptr; /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need to handle it because it's about to be overwritten. 
*/ - ++ss->reuseds; + ++du->downstream->reuseds; ++g_stats.downstreamTimeouts; handleDOHTimeout(oldDU); } @@ -554,16 +562,16 @@ static int processDOHQuery(DOHUnit* du) ids->destHarvested = false; } - if (ss->useProxyProtocol) { + if (du->downstream->useProxyProtocol) { addProxyProtocol(dq); } - int fd = pickBackendSocketForSending(ss); + int fd = pickBackendSocketForSending(du->downstream); try { /* you can't touch du after this line, because it might already have been freed */ - ssize_t ret = udpClientSendRequestToBackend(ss, fd, du->query); + ssize_t ret = udpClientSendRequestToBackend(du->downstream, fd, du->query); - if(ret < 0) { + if (ret < 0) { /* we are about to handle the error, make sure that this pointer is not accessed when the state is cleaned, but first check that it still belongs to us */ @@ -571,9 +579,9 @@ static int processDOHQuery(DOHUnit* du) ids->du = nullptr; du->release(); duRefCountIncremented = false; - --ss->outstanding; + --du->downstream->outstanding; } - ++ss->sendErrors; + ++du->downstream->sendErrors; ++g_stats.downstreamSendErrors; du->status_code = 502; return -1; @@ -586,7 +594,7 @@ static int processDOHQuery(DOHUnit* du) throw; } - vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), ss->getName()); + vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), du->downstream->getName()); } catch(const std::exception& e) { vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); @@ -1120,6 +1128,94 @@ static void dnsdistclient(int qsock) } } +class DoHTCPCrossQuerySender : public TCPQuerySender +{ +public: + DoHTCPCrossQuerySender(DOHUnit* du_): du(du_) + { + } + + ~DoHTCPCrossQuerySender() + { + if (du != nullptr) { + du->release(); + } + } + + bool active() const override + { + 
return true; + } + + const ClientState& getClientState() override + { + if (!du || !du->dsc || !du->dsc->cs) { + throw std::runtime_error("No query associated to this DoHTCPCrossQuerySender"); + } + + return *du->dsc->cs; + } + + void handleResponse(const struct timeval& now, TCPResponse&& response) override + { + if (!du) { + return; + } + + if (du->rsock == -1) { + return; + } + + du->response = std::move(response.d_buffer); + + auto sent = write(du->rsock, &du, sizeof(du)); + if (sent != sizeof(du)) { + du->release(); + du = nullptr; + } + } + + void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override + { + throw std::runtime_error("Oops"); + } + + void notifyIOError(IDState&& query, const struct timeval& now) override + { + throw std::runtime_error("Oops"); + } + +private: + DOHUnit* du{nullptr}; +}; + +class DoHCrossProtocolQuery : public CrossProtocolQuery +{ +public: + DoHCrossProtocolQuery(DOHUnit* du_): du(du_) + { + query = InternalQuery(std::move(du->query), std::move(du->ids)); + downstream = du->downstream; + } + + ~DoHCrossProtocolQuery() + { + if (du != nullptr) { + du->release(); + } + } + + std::shared_ptr getTCPQuerySender() override + { + auto sender = std::make_shared(du); + du = nullptr; + return sender; + } + +private: + DOHUnit* du{nullptr}; +}; + /* Called in the main DoH thread if h2o finds that dnsdist gave us an answer by writing into the dohresponsepair[0] side of the pipe so from: - handleDOHTimeout() when we did not get a response fast enough (called @@ -1147,6 +1243,28 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) return; } + if (!du->response.empty() && !du->tcp) { + const dnsheader* dh = reinterpret_cast(du->response.data()); + + if (dh->tc) { + /* restoring the original ID */ + dnsheader* queryDH = reinterpret_cast(du->query.data()); + queryDH->id = htons(du->ids.origID); + + auto cpq = std::make_unique(du); + + du->get(); + du->tcp = true; + + if (g_tcpclientthreads && 
g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { + return; + } + else { + du->release(); + } + } + } + if (du->self) { // we are back in the h2o main thread now, so we don't risk // a race (h2o killing the query) when accessing du->req anymore @@ -1452,6 +1570,32 @@ void dohThread(ClientState* cs) } } +void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state) +{ + static_assert(sizeof(*this) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + + response = std::move(udpResponse); + ids = std::move(state); + + auto du = this; + ssize_t sent = write(rsock, &du, sizeof(du)); + if (sent != sizeof(this)) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + ++g_stats.dohResponsePipeFull; + vinfolog("Unable to pass a DoH response to the DoH worker thread because the pipe is full"); + } + else { + vinfolog("Unable to pass a DoH response to the DoH worker thread because we couldn't write to the pipe: %s", stringerror()); + } + + /* at this point we have the only remaining pointer on this + DOHUnit object since we did set ids->du to nullptr earlier, + except if we got the response before the pointer could be + released by the frontend */ + release(); + } +} + #else /* HAVE_DNS_OVER_HTTPS */ void handleDOHTimeout(DOHUnit* oldDU) diff --git a/pdns/dnsdistdist/test-dnsdistkvs_cc.cc b/pdns/dnsdistdist/test-dnsdistkvs_cc.cc index 316fd78206..bc013bd089 100644 --- a/pdns/dnsdistdist/test-dnsdistkvs_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistkvs_cc.cc @@ -307,7 +307,7 @@ BOOST_AUTO_TEST_CASE(test_LMDB) { ComboAddress lc("192.0.2.1:53"); ComboAddress rem("192.0.2.128:42"); PacketBuffer packet(sizeof(dnsheader)); - auto proto = DNSQuestion::Protocol::DoUDP; + auto proto = dnsdist::Protocol::DoUDP; struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; @@ -387,7 +387,7 @@ BOOST_AUTO_TEST_CASE(test_CDB) { ComboAddress lc("192.0.2.1:53"); ComboAddress 
rem("192.0.2.128:42"); PacketBuffer packet(sizeof(dnsheader)); - auto proto = DNSQuestion::Protocol::DoUDP; + auto proto = dnsdist::Protocol::DoUDP; struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index b1a31886e3..d1f5080836 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -100,7 +100,7 @@ static DNSQuestion getDQ(const DNSName* providedName = nullptr) uint16_t qtype = QType::A; uint16_t qclass = QClass::IN; - auto proto = DNSQuestion::Protocol::DoUDP; + auto proto = dnsdist::Protocol::DoUDP; gettime(&queryRealTime, true); DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); diff --git a/pdns/dnsdistdist/test-dnsdistrules_cc.cc b/pdns/dnsdistdist/test-dnsdistrules_cc.cc index 6ae2a6e14b..70ccbd0020 100644 --- a/pdns/dnsdistdist/test-dnsdistrules_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistrules_cc.cc @@ -23,7 +23,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { ComboAddress lc("127.0.0.1:53"); ComboAddress rem("192.0.2.1:42"); PacketBuffer packet(sizeof(dnsheader)); - auto proto = DNSQuestion::Protocol::DoUDP; + auto proto = dnsdist::Protocol::DoUDP; struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 4f506fefc1..c1a1398ee8 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -2647,6 +2647,11 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) /* the backend descriptor becomes ready */ dynamic_cast(threadData.mplexer.get())->setReady(desc); } }, + /* no more query from the client for now */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 , [&threadData](int desc, const ExpectedStep& 
step) { + /* the client descriptor becomes NOT ready */ + dynamic_cast(threadData.mplexer.get())->setNotReady(-1); + } }, /* read the response (1) from the backend */ { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, 2 }, { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, axfrResponses.at(0).size() - 2 }, diff --git a/pdns/doh.hh b/pdns/doh.hh index 58e4cfba8f..9a2d12bb4d 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -173,7 +173,10 @@ struct DOHUnit #else /* HAVE_DNS_OVER_HTTPS */ #include +#include "dnsdist-idstate.hh" + struct st_h2o_req_t; +struct DownstreamState; struct DOHUnit { @@ -199,9 +202,12 @@ struct DOHUnit } } + void handleUDPResponse(PacketBuffer&& response, IDState&& state); + std::vector> headers; PacketBuffer query; PacketBuffer response; + IDState ids; std::string sni; std::string path; std::string scheme; @@ -211,6 +217,7 @@ struct DOHUnit st_h2o_req_t* req{nullptr}; DOHUnit** self{nullptr}; DOHServerConfig* dsc{nullptr}; + std::shared_ptr downstream{nullptr}; std::string contentType; std::atomic d_refcnt{1}; size_t query_at{0}; @@ -224,6 +231,9 @@ struct DOHUnit */ uint16_t status_code{200}; bool ednsAdded{false}; + /* whether the query was re-sent to the backend over + TCP after receiving a truncated answer over UDP */ + bool tcp{false}; std::string getHTTPPath() const; std::string getHTTPHost() const; diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 2466297be9..2d49ad2439 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -62,7 +62,7 @@ static void validateECS(const PacketBuffer& packet, const ComboAddress& expected uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &rem, const_cast(packet), DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &rem, const_cast(packet), dnsdist::Protocol::DoUDP, 
nullptr); BOOST_CHECK(parseEDNSOptions(dq)); BOOST_REQUIRE(dq.ednsOptions != nullptr); BOOST_CHECK_EQUAL(dq.ednsOptions->size(), 1U); @@ -113,7 +113,7 @@ BOOST_AUTO_TEST_CASE(test_addXPF) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); BOOST_CHECK(addXPF(dq, xpfOptionCode)); BOOST_CHECK(packet.size() > query.size()); @@ -132,7 +132,7 @@ BOOST_AUTO_TEST_CASE(test_addXPF) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); BOOST_REQUIRE(!addXPF(dq, xpfOptionCode)); BOOST_CHECK_EQUAL(packet.size(), 4096U); @@ -150,7 +150,7 @@ BOOST_AUTO_TEST_CASE(test_addXPF) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); /* add trailing data */ const size_t trailingDataSize = 10; @@ -337,7 +337,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(!parseEDNSOptions(dq)); @@ -360,7 +360,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); 
BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded)); BOOST_CHECK_GT(packet.size(), query.size()); @@ -439,7 +439,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(parseEDNSOptions(dq)); @@ -461,7 +461,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded)); BOOST_CHECK_GT(packet.size(), query.size()); @@ -537,7 +537,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) { BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); dq.ecsOverride = true; /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ @@ -1430,7 +1430,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) { static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, 
const ComboAddress& rem, const struct timespec& realTime, PacketBuffer& query) { - return DNSQuestion(&qname, qtype, qclass, &lc, &rem, query, DNSQuestion::Protocol::DoUDP, &realTime); + return DNSQuestion(&qname, qtype, qclass, &lc, &rem, query, dnsdist::Protocol::DoUDP, &realTime); } static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, PacketBuffer& query, bool resizeBuffer=true) @@ -1933,7 +1933,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { unsigned int consumed = 0; uint16_t qtype; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); BOOST_CHECK(packet.size() > query.size()); @@ -1957,7 +1957,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { unsigned int consumed = 0; uint16_t qtype; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); BOOST_CHECK(packet.size() > queryWithEDNS.size()); @@ -1985,7 +1985,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { unsigned int consumed = 0; uint16_t qtype; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), 
false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); BOOST_CHECK(packet.size() > query.size()); @@ -2009,7 +2009,7 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { unsigned int consumed = 0; uint16_t qtype; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); BOOST_CHECK(packet.size() > queryWithEDNS.size()); @@ -2050,7 +2050,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2071,7 +2071,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); 
BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2092,7 +2092,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, DNSQuestion::Protocol::DoUDP, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); BOOST_CHECK(!parseEDNSOptions(dq)); } diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 40759684df..7af31fa110 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -50,7 +50,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -81,7 +81,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); if (found == true) { auto removed = PC.expungeByName(a); @@ -100,7 +100,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); if 
(PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP)) { matches++; } @@ -161,7 +161,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -192,7 +192,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) { pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); if (PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP)) { matches++; } @@ -257,7 +257,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { /* UDP */ uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -272,7 +272,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { /* same but over TCP */ uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoTCP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoTCP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, !receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -316,7 +316,7 @@ 
BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -369,7 +369,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -421,7 +421,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -470,7 +470,7 @@ static void threadMangler(unsigned int offset) uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); @@ -500,7 +500,7 @@ static void threadReader(unsigned int offset) uint32_t key = 0; boost::optional subnet; - 
DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); if (!found) { g_missing++; @@ -576,7 +576,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { ComboAddress remote("192.0.2.1"); struct timespec queryTime; gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnetOut, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_REQUIRE(subnetOut); @@ -619,7 +619,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { ComboAddress remote("192.0.2.1"); struct timespec queryTime; gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &secondKey, subnetOut, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK_EQUAL(secondKey, key); @@ -695,7 +695,7 @@ BOOST_AUTO_TEST_CASE(test_PCDNSSECCollision) { ComboAddress remote("192.0.2.1"); struct timespec queryTime; gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, DNSQuestion::Protocol::DoUDP, &queryTime); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); bool found = PC.get(dq, 0, &key, subnetOut, true, receivedOverUDP); BOOST_CHECK_EQUAL(found, false);