From: Remi Gacogne Date: Wed, 28 Sep 2022 15:21:16 +0000 (+0200) Subject: dnsdist: Stronger guarantees against data race in the UDP path X-Git-Tag: dnsdist-1.8.0-rc1~123^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=49942e45892cb017c6bb44cb959df4321cb9a79e;p=thirdparty%2Fpdns.git dnsdist: Stronger guarantees against data race in the UDP path --- diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index 1932301ad0..1fb8116be1 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -162,73 +162,25 @@ struct IDState IDState(const IDState& orig) = delete; IDState(IDState&& rhs) { - if (rhs.isInUse()) { - throw std::runtime_error("Trying to move an in-use IDState"); - } - -#ifdef __SANITIZE_THREAD__ + inUse.store(rhs.inUse.load()); age.store(rhs.age.load()); -#else - age = rhs.age; -#endif internal = std::move(rhs.internal); } IDState& operator=(IDState&& rhs) { - if (isInUse()) { - throw std::runtime_error("Trying to overwrite an in-use IDState"); - } - - if (rhs.isInUse()) { - throw std::runtime_error("Trying to move an in-use IDState"); - } -#ifdef __SANITIZE_THREAD__ + inUse.store(rhs.inUse.load()); age.store(rhs.age.load()); -#else - age = rhs.age; -#endif - internal = std::move(rhs.internal); - return *this; } - static const int64_t unusedIndicator = -1; - - static bool isInUse(int64_t usageIndicator) - { - return usageIndicator != unusedIndicator; - } - bool isInUse() const { - return usageIndicator != unusedIndicator; - } - - /* return true if the value has been successfully replaced meaning that - no-one updated the usage indicator in the meantime */ - bool tryMarkUnused(int64_t expectedUsageIndicator) - { - return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator); - } - - /* mark as used no matter what, return true if the state was in use before */ - bool markAsUsed() - { - auto currentGeneration = generation++; - return markAsUsed(currentGeneration); - } - - /* mark as used no matter what, return true if the state was in use before */ - bool markAsUsed(int64_t currentGeneration) - { - int64_t oldUsage = usageIndicator.exchange(currentGeneration); - return oldUsage != unusedIndicator; + return inUse == true; } - /* We use this value to detect whether this state is in use. - For performance reasons we don't want to use a lock here, but that means + /* For performance reasons we don't want to use a lock here, but that means we need to be very careful when modifying this value. Modifications happen from: - one of the UDP or DoH 'client' threads receiving a query, selecting a backend @@ -246,26 +198,52 @@ struct IDState the corresponding state and sending the response to the client ; - the 'healthcheck' thread scanning the states to actively discover timeouts, mostly to keep some counters like the 'outstanding' one sane. - We previously based that logic on the origFD (FD on which the query was received, - and therefore from where the response should be sent) but this suffered from an - ABA problem since it was quite likely that a UDP 'client thread' would reset it to the - same value since we only have so much incoming sockets: - - 1/ 'client' thread gets a query and set origFD to its FD, say 5 ; - - 2/ 'receiver' thread gets a response, read the value of origFD to 5, check that the qname, - qtype and qclass match - - 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ; - - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still - 5, except it's not the same 5 anymore and it overrides a fresh state. - We now use a 32-bit unsigned counter instead, which is incremented every time the state is set, - wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1 - when the state is unused and the value of our counter otherwise. + + We have two flags: + - inUse tells us if there currently is a in-flight query whose state is stored + in this state + - locked tells us whether someone currently owns the state, so no-one else can touch + it */ InternalQueryState internal; - std::atomic usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8 - std::atomic generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4 -#ifdef __SANITIZE_THREAD__ std::atomic age{0}; -#else - uint16_t age{0}; // 2 -#endif + + class StateGuard + { + public: + StateGuard(IDState& ids) : + d_ids(ids) + { + } + ~StateGuard() + { + d_ids.release(); + } + StateGuard(const StateGuard&) = delete; + StateGuard(StateGuard&&) = delete; + StateGuard& operator=(const StateGuard&) = delete; + StateGuard& operator=(StateGuard&&) = delete; + + private: + IDState& d_ids; + }; + + [[nodiscard]] std::optional acquire() + { + bool expected = false; + if (locked.compare_exchange_strong(expected, true)) { + return std::optional(*this); + } + return std::nullopt; + } + + void release() + { + locked.store(false); + } + + std::atomic inUse{false}; // 1 + +private: + std::atomic locked{false}; // 1 }; diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 6b751c281f..0b6aa516f3 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -634,7 +634,7 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, doLatencyStats(incomingProtocol, udiff); } -static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& ds, bool selfGenerated, std::optional queryId) +static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& ds, bool selfGenerated) { DNSResponse dr(ids, response, ds); @@ -653,9 +653,6 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) { - if (queryId) { - ds->releaseState(*queryId); - } return; } @@ -686,10 +683,6 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re else { handleResponseSent(ids, 0., dr.ids.origRemote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP); } - - if (queryId) { - ds->releaseState(*queryId); - } } // listens on a dedicated socket, lobs answers from downstream servers to original requestors @@ -728,57 +721,23 @@ void responderThread(std::shared_ptr dss) dnsheader* dh = reinterpret_cast(response.data()); queryId = dh->id; - IDState* ids = dss->getExistingState(queryId); - if (ids == nullptr) { + auto ids = dss->getState(queryId); + if (!ids) { continue; } - int64_t usageIndicator = ids->usageIndicator; - - if (!IDState::isInUse(usageIndicator)) { - /* the corresponding state is marked as not in use, meaning that: - - it was already cleaned up by another thread and the state is gone ; - - we already got a response for this query and this one is a duplicate. - Either way, we don't touch it. - */ - continue; - } - - /* setting age to 0 to prevent the maintainer thread from - cleaning this IDS while we process the response. - */ - ids->age = 0; - unsigned int qnameWireLength = 0; - if (fd != ids->internal.backendFD || !responseContentMatches(response, ids->internal.qname, ids->internal.qtype, ids->internal.qclass, dss, qnameWireLength)) { + if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) { + dss->restoreState(queryId, std::move(*ids)); continue; } - DOHUnitUniquePtr du(nullptr, DOHUnit::release); - /* atomically mark the state as available, but only if it has not been altered - in the meantime */ - if (ids->tryMarkUnused(usageIndicator)) { - /* clear the potential DOHUnit asap, it's ours now - and since we just marked the state as unused, - someone could overwrite it. */ - du = std::move(ids->internal.du); - /* we only decrement the outstanding counter if the value was not - altered in the meantime, which would mean that the state has been actively reused - and the other thread has not incremented the outstanding counter, so we don't - want it to be decremented twice. */ - --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket - } else { - /* someone updated the state in the meantime, we can't touch the existing pointer */ - du.release(); - /* since the state has been updated, we can't safely access it so let's just drop - this response */ - continue; - } + auto du = std::move(ids->du); - dh->id = ids->internal.origID; + dh->id = ids->origID; ++dss->responses; - double udiff = ids->internal.queryRealTime.udiff(); + double udiff = ids->queryRealTime.udiff(); // do that _before_ the processing, otherwise it's not fair to the backend dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0; dss->reportResponse(dh->rcode); @@ -787,13 +746,12 @@ void responderThread(std::shared_ptr dss) if (du) { #ifdef HAVE_DNS_OVER_HTTPS // DoH query, we cannot touch du after that - handleUDPResponseForDoH(std::move(du), std::move(response), std::move(ids->internal)); + handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids)); #endif - dss->releaseState(queryId); continue; } - handleResponseForUDPClient(ids->internal, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId); + handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false); } } catch (const std::exception& e) { @@ -1445,7 +1403,7 @@ public: static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated, std::nullopt); + handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated); } void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override @@ -1487,62 +1445,55 @@ public: } }; -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) { bool doh = dq.ids.du != nullptr; - unsigned int idOffset = 0; - int64_t generation; - IDState* ids = ds->getIDState(idOffset, generation); - - dq.getHeader()->id = idOffset; bool failed = false; + size_t proxyPayloadSize = 0; if (ds->d_config.useProxyProtocol) { try { - size_t payloadSize = 0; - if (addProxyProtocol(dq, &payloadSize)) { + if (addProxyProtocol(dq, &proxyPayloadSize)) { if (dq.ids.du) { - dq.ids.du->proxyProtocolPayloadSize = payloadSize; + dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize; } } } catch (const std::exception& e) { vinfolog("Adding proxy protocol payload to %squery from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what()); - failed = true; + return false; } } try { - if (!failed) { - int fd = ds->pickSocketForSending(); - dq.ids.backendFD = fd; - dq.ids.origID = queryID; - dq.ids.forwardedOverUDP = true; - ids->internal = std::move(dq.ids); + int fd = ds->pickSocketForSending(); + dq.ids.backendFD = fd; + dq.ids.origID = queryID; + dq.ids.forwardedOverUDP = true; - vinfolog("Got query for %s|%s from %s%s, relayed to %s", ids->internal.qname.toLogString(), QType(ids->internal.qtype).toString(), ids->internal.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr()); - /* you can't touch du after this line, unless the call returned a non-negative value, - because it might already have been freed */ - ssize_t ret = udpClientSendRequestToBackend(ds, fd, query); + vinfolog("Got query for %s|%s from %s%s, relayed to %s", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr()); - if (ret < 0) { - failed = true; - } - } - else { - ids->internal = std::move(dq.ids); + auto idOffset = ds->saveState(std::move(dq.ids)); + /* set the correct ID */ + memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset)); + + /* you can't touch ids or du after this line, unless the call returned a non-negative value, + because it might already have been freed */ + ssize_t ret = udpClientSendRequestToBackend(ds, fd, query); + + if (ret < 0) { + failed = true; } if (failed) { - /* we are about to handle the error, make sure that - this pointer is not accessed when the state is cleaned, - but first check that it still belongs to us */ - if (ids->tryMarkUnused(generation) && ids->internal.du) { - dq.ids.du = std::move(ids->internal.du); - --ds->outstanding; - } - if (dq.ids.du) { - dq.ids.du->status_code = 502; + /* clear up the state. In the very unlikely event it was reused + in the meantime, so be it. */ + auto cleared = ds->getState(idOffset); + if (cleared) { + dq.ids.du = std::move(cleared->du); + if (dq.ids.du) { + dq.ids.du->status_code = 502; + } } ++g_stats.downstreamSendErrors; ++ds->sendErrors; @@ -1667,7 +1618,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } - assignOutgoingUDPQueryToBackend(ss, dh->id, dq, std::move(query), dest); + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest); } catch(const std::exception& e){ vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what()); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 7c8a3ec364..a425018e6e 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -1002,13 +1002,13 @@ public: int pickSocketForSending(); void pickSocketsReadyForReceiving(std::vector& ready); void handleUDPTimeouts(); - IDState* getIDState(unsigned int& id, int64_t& generation); - IDState* getExistingState(unsigned int id); - void releaseState(unsigned int id); void reportTimeoutOrError(); void reportResponse(uint8_t rcode); void submitHealthCheckResult(bool initial, bool newState); time_t getNextLazyHealthCheck(); + uint16_t saveState(InternalQueryState&&); + void restoreState(uint16_t id, InternalQueryState&&); + std::optional getState(uint16_t id); dnsdist::Protocol getProtocol() const { @@ -1206,7 +1206,7 @@ static const size_t s_maxPacketCacheEntrySize{4096}; // don't cache responses la enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend }; ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend); -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest); +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol); diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index daff23d6aa..eea141e440 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -321,19 +321,16 @@ bool DownstreamState::s_randomizeSockets{false}; bool DownstreamState::s_randomizeIDs{false}; int DownstreamState::s_udpTimeout{2}; -static bool isIDSExpired(IDState& ids) +static bool isIDSExpired(const IDState& ids) { - auto age = ids.age++; + auto age = ids.age.load(); return age > DownstreamState::s_udpTimeout; } void DownstreamState::handleUDPTimeout(IDState& ids) { - /* We mark the state as unused as soon as possible - to limit the risk of racing with the - responder thread. - */ ids.age = 0; + ids.inUse = false; handleDOHTimeout(std::move(ids.internal.du)); reuseds++; --outstanding; @@ -386,20 +383,26 @@ void DownstreamState::handleUDPTimeouts() it = map->erase(it); continue; } + ++ids.age; ++it; } } else { if (outstanding.load() > 0) { for (IDState& ids : idStates) { - int64_t usageIndicator = ids.usageIndicator; - if (IDState::isInUse(usageIndicator) && isIDSExpired(ids)) { - if (!ids.tryMarkUnused(usageIndicator)) { - /* this state has been altered in the meantime, - don't go anywhere near it */ - continue; - } - + if (!ids.isInUse()) { + continue; + } + if (!isIDSExpired(ids)) { + ++ids.age; + continue; + } + auto guard = ids.acquire(); + if (!guard) { + continue; + } + /* check again, now that we have locked this state */ + if (ids.isInUse() && isIDSExpired(ids)) { handleUDPTimeout(ids); } } @@ -407,89 +410,143 @@ void DownstreamState::handleUDPTimeouts() } } -IDState* DownstreamState::getExistingState(unsigned int stateId) +uint16_t DownstreamState::saveState(InternalQueryState&& state) { if (s_randomizeIDs) { + /* if the state is already in use we will retry, + up to 5 five times. The last selected one is used + even if it was already in use */ + size_t remainingAttempts = 5; auto map = d_idStatesMap.lock(); - auto it = map->find(stateId); - if (it == map->end()) { - return nullptr; + + do { + uint16_t selectedID = dnsdist::getRandomValue(std::numeric_limits::max()); + auto [it, inserted] = map->emplace(selectedID, IDState()); + + if (!inserted) { + remainingAttempts--; + if (remainingAttempts > 0) { + continue; + } + + auto oldDU = std::move(it->second.internal.du); + ++reuseds; + ++g_stats.downstreamTimeouts; + handleDOHTimeout(std::move(oldDU)); + } + else { + ++outstanding; + } + + it->second.internal = std::move(state); + it->second.age.store(0); + + return it->first; } - return &it->second; + while (true); } - else { - if (stateId >= idStates.size()) { - return nullptr; + + do { + IDState* ids = nullptr; + uint16_t selectedID = (idOffset++) % idStates.size(); + ids = &idStates[selectedID]; + auto guard = ids->acquire(); + if (!guard) { + continue; + } + if (ids->isInUse()) { + /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need + to handle it because it's about to be overwritten. */ + auto oldDU = std::move(ids->internal.du); + ++reuseds; + ++g_stats.downstreamTimeouts; + handleDOHTimeout(std::move(oldDU)); } - return &idStates[stateId]; + else { + ++outstanding; + } + ids->internal = std::move(state); + ids->age.store(0); + ids->inUse = true; + return selectedID; } + while (true); } -void DownstreamState::releaseState(unsigned int stateId) +void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state) { if (s_randomizeIDs) { auto map = d_idStatesMap.lock(); - auto it = map->find(stateId); - if (it == map->end()) { - return; + + auto [it, inserted] = map->emplace(id, IDState()); + if (!inserted) { + /* already used */ + ++reuseds; + ++g_stats.downstreamTimeouts; + handleDOHTimeout(std::move(state.du)); } - if (it->second.isInUse()) { - return; + else { + it->second.internal = std::move(state); + ++outstanding; } - map->erase(it); + return; } + + auto& ids = idStates[id]; + auto guard = ids.acquire(); + if (!guard) { + /* already used */ + ++reuseds; + ++g_stats.downstreamTimeouts; + handleDOHTimeout(std::move(state.du)); + return; + } + if (ids.isInUse()) { + /* already used */ + ++reuseds; + ++g_stats.downstreamTimeouts; + handleDOHTimeout(std::move(state.du)); + return; + } + ids.internal = std::move(state); + ids.inUse = true; + ++outstanding; } -IDState* DownstreamState::getIDState(unsigned int& selectedID, int64_t& generation) +std::optional DownstreamState::getState(uint16_t id) { - DOHUnitUniquePtr du(nullptr, DOHUnit::release); - IDState* ids = nullptr; + std::optional result = std::nullopt; + if (s_randomizeIDs) { - /* if the state is already in use we will retry, - up to 5 five times. The last selected one is used - even if it was already in use */ - size_t remainingAttempts = 5; auto map = d_idStatesMap.lock(); - bool done = false; - do { - selectedID = dnsdist::getRandomValue(std::numeric_limits::max()); - auto [it, inserted] = map->insert({selectedID, IDState()}); - ids = &it->second; - if (inserted) { - done = true; - } - else { - remainingAttempts--; - } + auto it = map->find(id); + if (it == map->end()) { + return result; } - while (!done && remainingAttempts > 0); - } - else { - selectedID = (idOffset++) % idStates.size(); - ids = &idStates[selectedID]; - } - ids->age = 0; + result = std::move(it->second.internal); + map->erase(it); + --outstanding; + return result; + } - /* we atomically replace the value, we now own this state */ - generation = ids->generation++; - if (!ids->markAsUsed(generation)) { - /* the state was not in use. - we reset 'du' because it might have still been in use when we read it. */ - du.release(); - ++outstanding; + if (id > idStates.size()) { + return result; } - else { - /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need - to handle it because it's about to be overwritten. */ - auto oldDU = std::move(ids->internal.du); - ++reuseds; - ++g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(oldDU)); + + auto& ids = idStates[id]; + auto guard = ids.acquire(); + if (!guard) { + return result; } - return ids; + if (ids.isInUse()) { + result = std::move(ids.internal); + --outstanding; + } + ids.inUse = false; + return result; } bool DownstreamState::healthCheckRequired(std::optional currentTime) diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index d6eb18bbc9..b66836b367 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -677,7 +677,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit) } ComboAddress dest = dq.ids.origDest; - if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, std::move(du->query), dest)) { + if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, du->query, dest)) { sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); return; } diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 71e140b99c..dacc245387 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -42,7 +42,7 @@ bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&) return false; } -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) { return false; }