IDState(const IDState& orig) = delete;
IDState(IDState&& rhs)
{
- if (rhs.isInUse()) {
- throw std::runtime_error("Trying to move an in-use IDState");
- }
-
-#ifdef __SANITIZE_THREAD__
+ inUse.store(rhs.inUse.load()); // carry over the in-flight flag (rhs.inUse itself is left unchanged); 'locked' is not copied, so the new object keeps its default-initialized value
age.store(rhs.age.load());
-#else
- age = rhs.age;
-#endif
internal = std::move(rhs.internal);
}
IDState& operator=(IDState&& rhs)
{
- if (isInUse()) {
- throw std::runtime_error("Trying to overwrite an in-use IDState");
- }
-
- if (rhs.isInUse()) {
- throw std::runtime_error("Trying to move an in-use IDState");
- }
-#ifdef __SANITIZE_THREAD__
+ inUse.store(rhs.inUse.load()); // take over the in-flight flag of rhs, whatever the current value is on either side; the 'locked' flag of neither object is touched here
age.store(rhs.age.load());
-#else
- age = rhs.age;
-#endif
-
internal = std::move(rhs.internal);
-
return *this;
}
- static const int64_t unusedIndicator = -1;
-
- static bool isInUse(int64_t usageIndicator)
- {
- return usageIndicator != unusedIndicator;
- }
-
bool isInUse() const
{
- return usageIndicator != unusedIndicator;
- }
-
- /* return true if the value has been successfully replaced meaning that
- no-one updated the usage indicator in the meantime */
- bool tryMarkUnused(int64_t expectedUsageIndicator)
- {
- return usageIndicator.compare_exchange_strong(expectedUsageIndicator, unusedIndicator);
- }
-
- /* mark as used no matter what, return true if the state was in use before */
- bool markAsUsed()
- {
- auto currentGeneration = generation++;
- return markAsUsed(currentGeneration);
- }
-
- /* mark as used no matter what, return true if the state was in use before */
- bool markAsUsed(int64_t currentGeneration)
- {
- int64_t oldUsage = usageIndicator.exchange(currentGeneration);
- return oldUsage != unusedIndicator;
+ /* atomically read the flag telling whether an in-flight query is stored in this state */
+ return inUse.load();
}
- /* We use this value to detect whether this state is in use.
- For performance reasons we don't want to use a lock here, but that means
+ /* For performance reasons we don't want to use a lock here, but that means
we need to be very careful when modifying this value. Modifications happen
from:
- one of the UDP or DoH 'client' threads receiving a query, selecting a backend
the corresponding state and sending the response to the client ;
- the 'healthcheck' thread scanning the states to actively discover timeouts,
mostly to keep some counters like the 'outstanding' one sane.
- We previously based that logic on the origFD (FD on which the query was received,
- and therefore from where the response should be sent) but this suffered from an
- ABA problem since it was quite likely that a UDP 'client thread' would reset it to the
- same value since we only have so much incoming sockets:
- - 1/ 'client' thread gets a query and set origFD to its FD, say 5 ;
- - 2/ 'receiver' thread gets a response, read the value of origFD to 5, check that the qname,
- qtype and qclass match
- - 3/ during that time the 'client' thread reuses the state, setting again origFD to 5 ;
- - 4/ the 'receiver' thread uses compare_exchange_strong() to only replace the value if it's still
- 5, except it's not the same 5 anymore and it overrides a fresh state.
- We now use a 32-bit unsigned counter instead, which is incremented every time the state is set,
- wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1
- when the state is unused and the value of our counter otherwise.
+
+ We have two flags:
+ - inUse tells us if there currently is an in-flight query whose state is stored
+ in this state
+ - locked tells us whether someone currently owns the state, so no-one else can touch
+ it
*/
InternalQueryState internal;
- std::atomic<int64_t> usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8
- std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4
-#ifdef __SANITIZE_THREAD__
std::atomic<uint16_t> age{0};
-#else
- uint16_t age{0}; // 2
-#endif
+
+ /* Scoped handle on an IDState: keeps a reference to the state and
+ calls release() on it when the guard is destroyed.
+ A guard can be neither copied nor moved. */
+ class StateGuard
+ {
+ public:
+ StateGuard(const StateGuard&) = delete;
+ StateGuard(StateGuard&&) = delete;
+ StateGuard& operator=(const StateGuard&) = delete;
+ StateGuard& operator=(StateGuard&&) = delete;
+
+ StateGuard(IDState& ids) : d_ids(ids) {}
+ ~StateGuard() { d_ids.release(); }
+
+ private:
+ IDState& d_ids;
+ };
+
+ /* Try to take ownership of this state, without blocking.
+ Returns a guard that will call release() when it goes out of scope,
+ or std::nullopt if the 'locked' flag was already set. */
+ [[nodiscard]] std::optional<StateGuard> acquire()
+ {
+ bool unlocked = false;
+ if (!locked.compare_exchange_strong(unlocked, true)) {
+ return std::nullopt;
+ }
+ return std::optional<StateGuard>(*this);
+ }
+
+ /* Give up ownership of this state by clearing the 'locked' flag. */
+ void release()
+ {
+ locked = false;
+ }
+
+ std::atomic<bool> inUse{false}; // 1
+
+private:
+ std::atomic<bool> locked{false}; // 1
};
doLatencyStats(incomingProtocol, udiff);
}
-static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated, std::optional<uint16_t> queryId)
+static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated)
{
DNSResponse dr(ids, response, ds);
memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
- if (queryId) {
- ds->releaseState(*queryId);
- }
return;
}
else {
handleResponseSent(ids, 0., dr.ids.origRemote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP);
}
-
- if (queryId) {
- ds->releaseState(*queryId);
- }
}
// listens on a dedicated socket, lobs answers from downstream servers to original requestors
dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
queryId = dh->id;
- IDState* ids = dss->getExistingState(queryId);
- if (ids == nullptr) {
+ auto ids = dss->getState(queryId);
+ if (!ids) {
continue;
}
- int64_t usageIndicator = ids->usageIndicator;
-
- if (!IDState::isInUse(usageIndicator)) {
- /* the corresponding state is marked as not in use, meaning that:
- - it was already cleaned up by another thread and the state is gone ;
- - we already got a response for this query and this one is a duplicate.
- Either way, we don't touch it.
- */
- continue;
- }
-
- /* setting age to 0 to prevent the maintainer thread from
- cleaning this IDS while we process the response.
- */
- ids->age = 0;
-
unsigned int qnameWireLength = 0;
- if (fd != ids->internal.backendFD || !responseContentMatches(response, ids->internal.qname, ids->internal.qtype, ids->internal.qclass, dss, qnameWireLength)) {
+ if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) {
+ dss->restoreState(queryId, std::move(*ids));
continue;
}
- DOHUnitUniquePtr du(nullptr, DOHUnit::release);
- /* atomically mark the state as available, but only if it has not been altered
- in the meantime */
- if (ids->tryMarkUnused(usageIndicator)) {
- /* clear the potential DOHUnit asap, it's ours now
- and since we just marked the state as unused,
- someone could overwrite it. */
- du = std::move(ids->internal.du);
- /* we only decrement the outstanding counter if the value was not
- altered in the meantime, which would mean that the state has been actively reused
- and the other thread has not incremented the outstanding counter, so we don't
- want it to be decremented twice. */
- --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket
- } else {
- /* someone updated the state in the meantime, we can't touch the existing pointer */
- du.release();
- /* since the state has been updated, we can't safely access it so let's just drop
- this response */
- continue;
- }
+ auto du = std::move(ids->du);
- dh->id = ids->internal.origID;
+ dh->id = ids->origID;
++dss->responses;
- double udiff = ids->internal.queryRealTime.udiff();
+ double udiff = ids->queryRealTime.udiff();
// do that _before_ the processing, otherwise it's not fair to the backend
dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
dss->reportResponse(dh->rcode);
if (du) {
#ifdef HAVE_DNS_OVER_HTTPS
// DoH query, we cannot touch du after that
- handleUDPResponseForDoH(std::move(du), std::move(response), std::move(ids->internal));
+ handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids));
#endif
- dss->releaseState(queryId);
continue;
}
- handleResponseForUDPClient(ids->internal, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId);
+ handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false);
}
}
catch (const std::exception& e) {
static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
- handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated, std::nullopt);
+ handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated);
}
void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
}
};
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
{
bool doh = dq.ids.du != nullptr;
- unsigned int idOffset = 0;
- int64_t generation;
- IDState* ids = ds->getIDState(idOffset, generation);
-
- dq.getHeader()->id = idOffset;
bool failed = false;
+ size_t proxyPayloadSize = 0;
if (ds->d_config.useProxyProtocol) {
try {
- size_t payloadSize = 0;
- if (addProxyProtocol(dq, &payloadSize)) {
+ if (addProxyProtocol(dq, &proxyPayloadSize)) {
if (dq.ids.du) {
- dq.ids.du->proxyProtocolPayloadSize = payloadSize;
+ dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize;
}
}
}
catch (const std::exception& e) {
vinfolog("Adding proxy protocol payload to %squery from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what());
- failed = true;
+ return false;
}
}
try {
- if (!failed) {
- int fd = ds->pickSocketForSending();
- dq.ids.backendFD = fd;
- dq.ids.origID = queryID;
- dq.ids.forwardedOverUDP = true;
- ids->internal = std::move(dq.ids);
+ int fd = ds->pickSocketForSending();
+ dq.ids.backendFD = fd;
+ dq.ids.origID = queryID;
+ dq.ids.forwardedOverUDP = true;
- vinfolog("Got query for %s|%s from %s%s, relayed to %s", ids->internal.qname.toLogString(), QType(ids->internal.qtype).toString(), ids->internal.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr());
- /* you can't touch du after this line, unless the call returned a non-negative value,
- because it might already have been freed */
- ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
+ vinfolog("Got query for %s|%s from %s%s, relayed to %s", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr());
- if (ret < 0) {
- failed = true;
- }
- }
- else {
- ids->internal = std::move(dq.ids);
+ auto idOffset = ds->saveState(std::move(dq.ids));
+ /* set the correct ID */
+ memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset));
+
+ /* you can't touch ids or du after this line, unless the call returned a non-negative value,
+ because it might already have been freed */
+ ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
+
+ if (ret < 0) {
+ failed = true;
}
if (failed) {
- /* we are about to handle the error, make sure that
- this pointer is not accessed when the state is cleaned,
- but first check that it still belongs to us */
- if (ids->tryMarkUnused(generation) && ids->internal.du) {
- dq.ids.du = std::move(ids->internal.du);
- --ds->outstanding;
- }
- if (dq.ids.du) {
- dq.ids.du->status_code = 502;
+ /* clear up the state. In the very unlikely event it was reused
+ in the meantime, so be it. */
+ auto cleared = ds->getState(idOffset);
+ if (cleared) {
+ dq.ids.du = std::move(cleared->du);
+ if (dq.ids.du) {
+ dq.ids.du->status_code = 502;
+ }
}
++g_stats.downstreamSendErrors;
++ds->sendErrors;
return;
}
- assignOutgoingUDPQueryToBackend(ss, dh->id, dq, std::move(query), dest);
+ assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest);
}
catch(const std::exception& e){
vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what());
int pickSocketForSending();
void pickSocketsReadyForReceiving(std::vector<int>& ready);
void handleUDPTimeouts();
- IDState* getIDState(unsigned int& id, int64_t& generation);
- IDState* getExistingState(unsigned int id);
- void releaseState(unsigned int id);
void reportTimeoutOrError();
void reportResponse(uint8_t rcode);
void submitHealthCheckResult(bool initial, bool newState);
time_t getNextLazyHealthCheck();
+ uint16_t saveState(InternalQueryState&&);
+ void restoreState(uint16_t id, InternalQueryState&&);
+ std::optional<InternalQueryState> getState(uint16_t id);
dnsdist::Protocol getProtocol() const
{
enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend };
ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest);
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest);
ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol);
bool DownstreamState::s_randomizeIDs{false};
int DownstreamState::s_udpTimeout{2};
-static bool isIDSExpired(IDState& ids)
+/* Tell whether the age of this state is past DownstreamState::s_udpTimeout.
+ This only reads the age, incrementing it is left to the caller. */
+static bool isIDSExpired(const IDState& ids)
{
- auto age = ids.age++;
- return age > DownstreamState::s_udpTimeout;
+ return ids.age.load() > DownstreamState::s_udpTimeout;
}
void DownstreamState::handleUDPTimeout(IDState& ids)
{
- /* We mark the state as unused as soon as possible
- to limit the risk of racing with the
- responder thread.
- */
ids.age = 0;
+ ids.inUse = false;
handleDOHTimeout(std::move(ids.internal.du));
reuseds++;
--outstanding;
it = map->erase(it);
continue;
}
+ ++ids.age;
++it;
}
}
else {
if (outstanding.load() > 0) {
for (IDState& ids : idStates) {
- int64_t usageIndicator = ids.usageIndicator;
- if (IDState::isInUse(usageIndicator) && isIDSExpired(ids)) {
- if (!ids.tryMarkUnused(usageIndicator)) {
- /* this state has been altered in the meantime,
- don't go anywhere near it */
- continue;
- }
-
+ if (!ids.isInUse()) {
+ continue;
+ }
+ if (!isIDSExpired(ids)) {
+ ++ids.age;
+ continue;
+ }
+ auto guard = ids.acquire();
+ if (!guard) {
+ continue;
+ }
+ /* check again, now that we have locked this state */
+ if (ids.isInUse() && isIDSExpired(ids)) {
handleUDPTimeout(ids);
}
}
}
}
-IDState* DownstreamState::getExistingState(unsigned int stateId)
+uint16_t DownstreamState::saveState(InternalQueryState&& state) // stores 'state' into a slot and returns the ID of the slot that was selected
{
if (s_randomizeIDs) {
+ /* if the state is already in use we will retry,
+ up to 5 times. The last selected one is used
+ even if it was already in use */
+ size_t remainingAttempts = 5;
auto map = d_idStatesMap.lock();
- auto it = map->find(stateId);
- if (it == map->end()) {
- return nullptr;
+
+ do {
+ uint16_t selectedID = dnsdist::getRandomValue(std::numeric_limits<uint16_t>::max());
+ auto [it, inserted] = map->emplace(selectedID, IDState()); // 'inserted' is false when the map already holds an entry for this ID
+
+ if (!inserted) {
+ remainingAttempts--;
+ if (remainingAttempts > 0) {
+ continue; // attempts left: draw another random ID
+ }
+
+ auto oldDU = std::move(it->second.internal.du); // out of attempts: this entry gets overwritten below, so its DOHUnit is handled as a timeout first
+ ++reuseds;
+ ++g_stats.downstreamTimeouts;
+ handleDOHTimeout(std::move(oldDU));
+ }
+ else {
+ ++outstanding; // new entry, so one more query in flight
+ }
+
+ it->second.internal = std::move(state);
+ it->second.age.store(0);
+
+ return it->first;
}
- return &it->second;
+ while (true);
}
- else {
- if (stateId >= idStates.size()) {
- return nullptr;
+
+ do {
+ IDState* ids = nullptr;
+ uint16_t selectedID = (idOffset++) % idStates.size();
+ ids = &idStates[selectedID];
+ auto guard = ids->acquire(); // released when 'guard' goes out of scope, at the end of this iteration or on return
+ if (!guard) {
+ continue; // someone else owns this slot right now, move on to the next one
+ }
+ if (ids->isInUse()) {
+ /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
+ to handle it because it's about to be overwritten. */
+ auto oldDU = std::move(ids->internal.du);
+ ++reuseds;
+ ++g_stats.downstreamTimeouts;
+ handleDOHTimeout(std::move(oldDU));
}
- return &idStates[stateId];
+ else {
+ ++outstanding;
+ }
+ ids->internal = std::move(state);
+ ids->age.store(0);
+ ids->inUse = true; // flagged as in use only once the state and age have been written
+ return selectedID;
}
+ while (true);
}
-void DownstreamState::releaseState(unsigned int stateId)
+void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state) // puts 'state' back into the slot for 'id'; if that slot is taken, 'state' is not stored and its DOHUnit is handled as a timeout
{
if (s_randomizeIDs) {
auto map = d_idStatesMap.lock();
- auto it = map->find(stateId);
- if (it == map->end()) {
- return;
+
+ auto [it, inserted] = map->emplace(id, IDState());
+ if (!inserted) {
+ /* already used: the map holds an entry for this ID, which is left untouched */
+ ++reuseds;
+ ++g_stats.downstreamTimeouts;
+ handleDOHTimeout(std::move(state.du));
}
- if (it->second.isInUse()) {
- return;
+ else {
+ it->second.internal = std::move(state);
+ ++outstanding;
}
- map->erase(it);
+ return;
}
+
+ auto& ids = idStates[id]; // NOTE(review): 'id' is not range-checked here, presumably callers only pass an ID they got a state for via getState() - confirm
+ auto guard = ids.acquire();
+ if (!guard) {
+ /* already used: someone else currently owns this slot */
+ ++reuseds;
+ ++g_stats.downstreamTimeouts;
+ handleDOHTimeout(std::move(state.du));
+ return;
+ }
+ if (ids.isInUse()) {
+ /* already used: the slot holds an in-flight query, which is left untouched */
+ ++reuseds;
+ ++g_stats.downstreamTimeouts;
+ handleDOHTimeout(std::move(state.du));
+ return;
+ }
+ ids.internal = std::move(state);
+ ids.inUse = true; // the age of the slot is not modified on this path
+ ++outstanding;
}
-IDState* DownstreamState::getIDState(unsigned int& selectedID, int64_t& generation)
+/* Take the query state stored under 'id' out of its slot.
+
+ Returns the state, which is then owned by the caller, and decrements
+ 'outstanding'. Returns std::nullopt, leaving everything untouched, when:
+ - there is no entry for 'id' (randomized IDs) ;
+ - 'id' is not a valid index into idStates ;
+ - the slot is currently owned by someone else ;
+ - the slot is not in use.
+*/
+std::optional<InternalQueryState> DownstreamState::getState(uint16_t id)
{
- DOHUnitUniquePtr du(nullptr, DOHUnit::release);
- IDState* ids = nullptr;
+ std::optional<InternalQueryState> result = std::nullopt;
+
if (s_randomizeIDs) {
- /* if the state is already in use we will retry,
- up to 5 five times. The last selected one is used
- even if it was already in use */
- size_t remainingAttempts = 5;
auto map = d_idStatesMap.lock();
- bool done = false;
- do {
- selectedID = dnsdist::getRandomValue(std::numeric_limits<uint16_t>::max());
- auto [it, inserted] = map->insert({selectedID, IDState()});
- ids = &it->second;
- if (inserted) {
- done = true;
- }
- else {
- remainingAttempts--;
- }
+ auto it = map->find(id);
+ if (it == map->end()) {
+ return result;
}
- while (!done && remainingAttempts > 0);
- }
- else {
- selectedID = (idOffset++) % idStates.size();
- ids = &idStates[selectedID];
- }
- ids->age = 0;
+ /* the entry is removed from the map, so the ID becomes available again */
+ result = std::move(it->second.internal);
+ map->erase(it);
+ --outstanding;
+ return result;
+ }
- /* we atomically replace the value, we now own this state */
- generation = ids->generation++;
- if (!ids->markAsUsed(generation)) {
- /* the state was not in use.
- we reset 'du' because it might have still been in use when we read it. */
- du.release();
- ++outstanding;
+ /* valid indexes are [0, size), so an ID equal to the size has to be
+ rejected as well, otherwise we would read past the end of idStates */
+ if (id >= idStates.size()) {
+ return result;
}
- else {
- /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
- to handle it because it's about to be overwritten. */
- auto oldDU = std::move(ids->internal.du);
- ++reuseds;
- ++g_stats.downstreamTimeouts;
- handleDOHTimeout(std::move(oldDU));
+
+ auto& ids = idStates[id];
+ /* the guard releases the slot when we return */
+ auto guard = ids.acquire();
+ if (!guard) {
+ return result;
}
- return ids;
+ if (ids.isInUse()) {
+ result = std::move(ids.internal);
+ --outstanding;
+ }
+ ids.inUse = false;
+ return result;
}
bool DownstreamState::healthCheckRequired(std::optional<time_t> currentTime)
}
ComboAddress dest = dq.ids.origDest;
- if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, std::move(du->query), dest)) {
+ if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, du->query, dest)) {
sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
return;
}
return false;
}
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) // this definition does not use any of its arguments and unconditionally returns false
{
return false;
}