From: Remi Gacogne Date: Thu, 5 Jul 2018 14:26:33 +0000 (+0200) Subject: dnsdist: Fix an outstanding counter race when reusing states X-Git-Tag: dnsdist-1.3.1~6^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F6773%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Fix an outstanding counter race when reusing states --- diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 9e8ad00555..0d1bf023ff 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -426,7 +426,6 @@ try { for(;;) { dnsheader* dh = reinterpret_cast<dnsheader*>(packet); - bool outstandingDecreased = false; try { pickBackendSocketsReadyForReceiving(dss, sockets); for (const auto& fd : sockets) { @@ -460,8 +459,14 @@ try { continue; } - --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket - outstandingDecreased = true; + int oldFD = ids->origFD.exchange(-1); + if (oldFD == origFD) { + /* we only decrement the outstanding counter if the value was not + altered in the meantime, which would mean that the state has been actively reused + and the other thread has not incremented the outstanding counter, so we don't + want it to be decremented twice.
*/ + --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket + } if(dh->tc && g_truncateTC) { truncateTC(response, &responseLen); @@ -522,12 +527,12 @@ try { doLatencyStats(udiff); - if (ids->origFD == origFD) { + /* if the FD is not -1, the state has been actively reused and we should + not alter anything */ + if (ids->origFD == -1) { #ifdef HAVE_DNSCRYPT ids->dnsCryptQuery = nullptr; #endif - ids->origFD = -1; - outstandingDecreased = false; } rewrittenResponse.clear(); @@ -535,14 +540,6 @@ try { } catch(const std::exception& e){ vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->remote.toStringWithPort(), queryId, e.what()); - if (outstandingDecreased) { - /* so an exception was raised after we decreased the outstanding queries counter, - but before we could set ids->origFD to -1 (because we also set outstandingDecreased - to false then), meaning the IDS is still considered active and we will decrease the - counter again on a duplicate, or simply while reaping downstream timeouts, so let's - increase it back. 
*/ - dss->outstanding++; - } } } return 0; @@ -1435,15 +1432,17 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct IDState* ids = &ss->idStates[idOffset]; ids->age = 0; - if(ids->origFD < 0) // if we are reusing, no change in outstanding + int oldFD = ids->origFD.exchange(cs.udpFD); + if(oldFD < 0) { + // if we are reusing, no change in outstanding ss->outstanding++; + } else { ss->reuseds++; g_stats.downstreamTimeouts++; } ids->cs = &cs; - ids->origFD = cs.udpFD; ids->origID = dh->id; ids->origRemote = remote; ids->sentTime.set(queryRealTime); @@ -1874,7 +1873,8 @@ void* healthChecksThread() dss->prev.reuseds.store(dss->reuseds.load()); for(IDState& ids : dss->idStates) { // timeouts - if(ids.origFD >=0 && ids.age++ > g_udpTimeout) { + int origFD = ids.origFD; + if(origFD >=0 && ids.age++ > g_udpTimeout) { /* We set origFD to -1 as soon as possible to limit the risk of racing with the responder thread. @@ -1883,7 +1883,11 @@ void* healthChecksThread() so the sooner the better any way since we _will_ decrement it. 
*/ - ids.origFD = -1; + if (ids.origFD.exchange(-1) != origFD) { + /* this state has been altered in the meantime, + don't go anywhere near it */ + continue; + } ids.age = 0; dss->reuseds++; --dss->outstanding; diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index cc30c09fd0..37fb8bde52 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -369,16 +369,15 @@ struct ClientState; struct IDState { IDState() : origFD(-1), sentTime(true), delayMsec(0), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;} - IDState(const IDState& orig): origRemote(orig.origRemote), origDest(orig.origDest) + IDState(const IDState& orig): origRemote(orig.origRemote), origDest(orig.origDest), age(orig.age) { - origFD = orig.origFD; + origFD.store(orig.origFD.load()); origID = orig.origID; delayMsec = orig.delayMsec; tempFailureTTL = orig.tempFailureTTL; - age.store(orig.age.load()); } - int origFD; // set to <0 to indicate this state is empty // 4 + std::atomic<int> origFD; // set to <0 to indicate this state is empty // 4 ComboAddress origRemote; // 28 ComboAddress origDest; // 28 @@ -394,7 +393,7 @@ struct IDState std::shared_ptr<QTag> qTag{nullptr}; const ClientState* cs{nullptr}; uint32_t cacheKey; // 8 - std::atomic<uint16_t> age; // 4 + uint16_t age; // 4 uint16_t qtype; // 2 uint16_t qclass; // 2 uint16_t origID; // 2