]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist.cc
dnsdist: Increase the default value of setMaxUDPOutstanding to 65535
[thirdparty/pdns.git] / pdns / dnsdist.cc
index babcc9972878f14790ffd9da249722fd7001c79c..e9ca2aff0d4d326fb034c8f04a947fc6283d7624 100644 (file)
@@ -83,7 +83,7 @@ bool g_verbose;
 struct DNSDistStats g_stats;
 MetricDefinitionStorage g_metricDefinitions;
 
-uint16_t g_maxOutstanding{10240};
+uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};
 bool g_verboseHealthChecks{false};
 uint32_t g_staleCacheEntriesTTL{0};
 bool g_syslog{true};
@@ -544,30 +544,50 @@ try {
         }
 
         IDState* ids = &dss->idStates[queryId];
-        int origFD = ids->origFD;
+        int64_t usageIndicator = ids->usageIndicator;
 
-        if(origFD < 0 && ids->du == nullptr) // duplicate
+        if(!IDState::isInUse(usageIndicator)) {
+          /* the corresponding state is marked as not in use, meaning that:
+             - it was already cleaned up by another thread and the state is gone ;
+             - we already got a response for this query and this one is a duplicate.
+             Either way, we don't touch it.
+          */
           continue;
+        }
 
+        /* read the potential DOHUnit state as soon as possible, but don't use it
+           until we have confirmed that we own this state by updating usageIndicator */
+        auto du = ids->du;
         /* setting age to 0 to prevent the maintainer thread from
            cleaning this IDS while we process the response.
-           We have already a copy of the origFD, so it would
-           mostly mess up the outstanding counter.
         */
         ids->age = 0;
+        int origFD = ids->origFD;
 
         unsigned int consumed = 0;
         if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote, consumed)) {
           continue;
         }
 
-        int oldFD = ids->origFD.exchange(-1);
-        if (oldFD == origFD) {
+        bool isDoH = du != nullptr;
+        /* atomically mark the state as available, but only if it has not been altered
+           in the meantime */
+        if (ids->tryMarkUnused(usageIndicator)) {
+          /* clear the potential DOHUnit asap, it's ours now
+           and since we just marked the state as unused,
+           someone could overwrite it. */
+          ids->du = nullptr;
           /* we only decrement the outstanding counter if the value was not
              altered in the meantime, which would mean that the state has been actively reused
              and the other thread has not incremented the outstanding counter, so we don't
              want it to be decremented twice. */
           --dss->outstanding;  // you'd think an attacker could game this, but we're using connected socket
+        } else {
+          /* someone updated the state in the meantime, we can't touch the existing pointer */
+          du = nullptr;
+          /* since the state has been updated, we can't safely access it so let's just drop
+             this response */
+          continue;
         }
 
         if(dh->tc && g_truncateTC) {
@@ -588,15 +608,17 @@ try {
         }
 
         if (ids->cs && !ids->cs->muted) {
-          if (ids->du) {
+          if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
             // DoH query
-            ids->du->query = std::string(response, responseLen);
-            if (send(ids->du->rsock, &ids->du, sizeof(ids->du), 0) != sizeof(ids->du)) {
-              delete ids->du;
+            du->response = std::string(response, responseLen);
+            if (send(du->rsock, &du, sizeof(du), 0) != sizeof(du)) {
+              /* at this point we have the only remaining pointer on this
+                 DOHUnit object since we did set ids->du to nullptr earlier */
+              delete du;
             }
 #endif /* HAVE_DNS_OVER_HTTPS */
-            ids->du = nullptr;
+            du = nullptr;
           }
           else {
             ComboAddress empty;
@@ -611,7 +633,7 @@ try {
 
         double udiff = ids->sentTime.udiff();
         vinfolog("Got answer from %s, relayed to %s%s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(),
-                 ids->du ? " (https)": "", udiff);
+                 isDoH ? " (https)": "", udiff);
 
         struct timespec ts;
         gettime(&ts);
@@ -1013,12 +1035,76 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
   }
 }
 
-static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, const struct timespec& now)
+bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop)
+{
+  switch(action) {
+  case DNSAction::Action::Allow:
+    return true;
+    break;
+  case DNSAction::Action::Drop:
+    ++g_stats.ruleDrop;
+    drop = true;
+    return true;
+    break;
+  case DNSAction::Action::Nxdomain:
+    dq.dh->rcode = RCode::NXDomain;
+    dq.dh->qr=true;
+    ++g_stats.ruleNXDomain;
+    return true;
+    break;
+  case DNSAction::Action::Refused:
+    dq.dh->rcode = RCode::Refused;
+    dq.dh->qr=true;
+    ++g_stats.ruleRefused;
+    return true;
+    break;
+  case DNSAction::Action::ServFail:
+    dq.dh->rcode = RCode::ServFail;
+    dq.dh->qr=true;
+    ++g_stats.ruleServFail;
+    return true;
+    break;
+  case DNSAction::Action::Spoof:
+    spoofResponseFromString(dq, ruleresult);
+    return true;
+    break;
+  case DNSAction::Action::Truncate:
+    dq.dh->tc = true;
+    dq.dh->qr = true;
+    return true;
+    break;
+  case DNSAction::Action::HeaderModify:
+    return true;
+    break;
+  case DNSAction::Action::Pool:
+    dq.poolname=ruleresult;
+    return true;
+    break;
+  case DNSAction::Action::NoRecurse:
+    dq.dh->rd = false;
+    return true;
+    break;
+    /* non-terminal actions follow */
+  case DNSAction::Action::Delay:
+    dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+    break;
+  case DNSAction::Action::None:
+    /* fall-through */
+  case DNSAction::Action::NoOp:
+    break;
+  }
+
+  /* false means that we don't stop the processing */
+  return false;
+}
+
+
+static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now)
 {
   g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
 
   if(g_qcount.enabled) {
-    string qname = (*dq.qname).toString(".");
+    string qname = (*dq.qname).toLogString();
     bool countQuery{true};
     if(g_qcount.filter) {
       std::lock_guard<std::mutex> lock(g_luamutex);
@@ -1075,7 +1161,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
           return true;
         }
         else {
-          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         }
         break;
       case DNSAction::Action::NoRecurse:
@@ -1107,14 +1193,14 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
         /* do nothing */
         break;
       case DNSAction::Action::Nxdomain:
-        vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+        vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         updateBlockStats();
 
         dq.dh->rcode = RCode::NXDomain;
         dq.dh->qr=true;
         return true;
       case DNSAction::Action::Refused:
-        vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+        vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         updateBlockStats();
 
         dq.dh->rcode = RCode::Refused;
@@ -1124,13 +1210,13 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
         if(!dq.tcp) {
           updateBlockStats();
       
-          vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+          vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
           dq.dh->tc = true;
           dq.dh->qr = true;
           return true;
         }
         else {
-          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         }
         break;
       case DNSAction::Action::NoRecurse:
@@ -1140,7 +1226,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
         return true;
       default:
         updateBlockStats();
-        vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+        vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         return false;
       }
     }
@@ -1148,69 +1234,21 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
 
   DNSAction::Action action=DNSAction::Action::None;
   string ruleresult;
+  bool drop = false;
   for(const auto& lr : *holders.rulactions) {
     if(lr.d_rule->matches(&dq)) {
       lr.d_rule->d_matches++;
       action=(*lr.d_action)(&dq, &ruleresult);
-
-      switch(action) {
-      case DNSAction::Action::Allow:
-        return true;
-        break;
-      case DNSAction::Action::Drop:
-        ++g_stats.ruleDrop;
-        return false;
-        break;
-      case DNSAction::Action::Nxdomain:
-        dq.dh->rcode = RCode::NXDomain;
-        dq.dh->qr=true;
-        ++g_stats.ruleNXDomain;
-        return true;
-        break;
-      case DNSAction::Action::Refused:
-        dq.dh->rcode = RCode::Refused;
-        dq.dh->qr=true;
-        ++g_stats.ruleRefused;
-        return true;
-        break;
-      case DNSAction::Action::ServFail:
-        dq.dh->rcode = RCode::ServFail;
-        dq.dh->qr=true;
-        ++g_stats.ruleServFail;
-        return true;
-        break;
-      case DNSAction::Action::Spoof:
-        spoofResponseFromString(dq, ruleresult);
-        return true;
-        break;
-      case DNSAction::Action::Truncate:
-        dq.dh->tc = true;
-        dq.dh->qr = true;
-        return true;
-        break;
-      case DNSAction::Action::HeaderModify:
-        return true;
-        break;
-      case DNSAction::Action::Pool:
-        poolname=ruleresult;
-        return true;
-        break;
-        /* non-terminal actions follow */
-      case DNSAction::Action::Delay:
-        dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
-        break;
-      case DNSAction::Action::None:
-        /* fall-through */
-      case DNSAction::Action::NoOp:
-        break;
-      case DNSAction::Action::NoRecurse:
-        dq.dh->rd = false;
-        return true;
+      if (processRulesResult(action, dq, ruleresult, drop)) {
         break;
       }
     }
   }
 
+  if (drop) {
+    return false;
+  }
+
   return true;
 }
 
@@ -1392,9 +1430,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
     struct timespec now;
     gettime(&now);
 
-    string poolname;
-
-    if (!applyRulesToQuery(holders, dq, poolname, now)) {
+    if (!applyRulesToQuery(holders, dq, now)) {
       return ProcessQueryResult::Drop;
     }
 
@@ -1409,7 +1445,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
       return ProcessQueryResult::SendAnswer;
     }
 
-    std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
+    std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.poolname);
     dq.packetCache = serverPool->packetCache;
     auto policy = *(holders.policy);
     if (serverPool->policy != nullptr) {
@@ -1472,7 +1508,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
     if(!selectedBackend) {
       ++g_stats.noPolicy;
 
-      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toLogString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
       if (g_servFailOnNoPolicy) {
         restoreFlags(dq.dh, dq.origFlags);
 
@@ -1564,19 +1600,33 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
     IDState* ids = &ss->idStates[idOffset];
     ids->age = 0;
-    ids->du = nullptr;
+    DOHUnit* du = nullptr;
+
+    /* that means that the state was in use, possibly with an allocated
+       DOHUnit that we will need to handle, but we can't touch it before
+       confirming that we now own this state */
+    if (ids->isInUse()) {
+      du = ids->du;
+    }
 
-    int oldFD = ids->origFD.exchange(cs.udpFD);
-    if(oldFD < 0) {
-      // if we are reusing, no change in outstanding
+    /* we atomically replace the value, we now own this state */
+    if (!ids->markAsUsed()) {
+      /* the state was not in use.
+         we reset 'du' because it might have still been in use when we read it. */
+      du = nullptr;
       ++ss->outstanding;
     }
     else {
+      /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
+         to handle it because it's about to be overwritten. */
+      ids->du = nullptr;
       ++ss->reuseds;
       ++g_stats.downstreamTimeouts;
+      handleDOHTimeout(du);
     }
 
     ids->cs = &cs;
+    ids->origFD = cs.udpFD;
     ids->origID = dh->id;
     setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
 
@@ -1605,7 +1655,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ++g_stats.downstreamSendErrors;
     }
 
-    vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
+    vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
   }
   catch(const std::exception& e){
     vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
@@ -2054,29 +2104,28 @@ static void healthChecksThread()
       dss->prev.reuseds.store(dss->reuseds.load());
       
       for(IDState& ids  : dss->idStates) { // timeouts
-        int origFD = ids.origFD;
-        if(origFD >=0 && ids.age++ > g_udpTimeout) {
-          /* We set origFD to -1 as soon as possible
+        int64_t usageIndicator = ids.usageIndicator;
+        if(IDState::isInUse(usageIndicator) && ids.age++ > g_udpTimeout) {
+          /* We mark the state as unused as soon as possible
              to limit the risk of racing with the
              responder thread.
-             The UDP client thread only checks origFD to
-             know whether outstanding has to be incremented,
-             so the sooner the better any way since we _will_
-             decrement it.
           */
-          if (ids.origFD.exchange(-1) != origFD) {
+          auto oldDU = ids.du;
+
+          if (!ids.tryMarkUnused(usageIndicator)) {
             /* this state has been altered in the meantime,
                don't go anywhere near it */
             continue;
           }
           ids.du = nullptr;
+          handleDOHTimeout(oldDU);
           ids.age = 0;
           dss->reuseds++;
           --dss->outstanding;
           ++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
           vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s",
                    dss->remote.toStringWithPort(), dss->name,
-                   ids.qname.toString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
+                   ids.qname.toLogString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
 
           struct timespec ts;
           gettime(&ts);