]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist.cc
dnsdist: Add constants and helper functions to the IDState
[thirdparty/pdns.git] / pdns / dnsdist.cc
index 95b1d7f8f2689efce3add9eacb41d6b370225935..c2b63c08ab3b61cb5a2a327d31b32f63db744bf6 100644 (file)
@@ -200,6 +200,7 @@ void doLatencyStats(double udiff)
   else if(udiff < 100000) ++g_stats.latency50_100;
   else if(udiff < 1000000) ++g_stats.latency100_1000;
   else ++g_stats.latencySlow;
+  g_stats.latencySum += udiff / 1000;
 
   auto doAvg = [](double& var, double n, double weight) {
     var = (weight -1) * var/weight + n/weight;
@@ -543,30 +544,50 @@ try {
         }
 
         IDState* ids = &dss->idStates[queryId];
-        int origFD = ids->origFD;
+        int64_t usageIndicator = ids->usageIndicator;
 
-        if(origFD < 0 && ids->du == nullptr) // duplicate
+        if(!IDState::isInUse(usageIndicator)) {
+          /* the corresponding state is marked as not in use, meaning that:
+             - it was already cleaned up by another thread and the state is gone ;
+             - we already got a response for this query and this one is a duplicate.
+             Either way, we don't touch it.
+          */
           continue;
+        }
 
+        /* read the potential DOHUnit state as soon as possible, but don't use it
+           until we have confirmed that we own this state by updating usageIndicator */
+        auto du = ids->du;
         /* setting age to 0 to prevent the maintainer thread from
            cleaning this IDS while we process the response.
-           We have already a copy of the origFD, so it would
-           mostly mess up the outstanding counter.
         */
         ids->age = 0;
+        int origFD = ids->origFD;
 
         unsigned int consumed = 0;
         if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote, consumed)) {
           continue;
         }
 
-        int oldFD = ids->origFD.exchange(-1);
-        if (oldFD == origFD) {
+        bool isDoH = du != nullptr;
+        /* atomically mark the state as available, but only if it has not been altered
+           in the meantime */
+        if (ids->tryMarkUnused(usageIndicator)) {
+          /* clear the potential DOHUnit asap, it's ours now
+           and since we just marked the state as unused,
+           someone could overwrite it. */
+          ids->du = nullptr;
           /* we only decrement the outstanding counter if the value was not
              altered in the meantime, which would mean that the state has been actively reused
              and the other thread has not incremented the outstanding counter, so we don't
              want it to be decremented twice. */
           --dss->outstanding;  // you'd think an attacker could game this, but we're using connected socket
+        } else {
+          /* someone updated the state in the meantime, we can't touch the existing pointer */
+          du = nullptr;
+          /* since the state has been updated, we can't safely access it so let's just drop
+             this response */
+          continue;
         }
 
         if(dh->tc && g_truncateTC) {
@@ -587,15 +608,17 @@ try {
         }
 
         if (ids->cs && !ids->cs->muted) {
-          if (ids->du) {
+          if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
             // DoH query
-            ids->du->query = std::string(response, responseLen);
-            if (send(ids->du->rsock, &ids->du, sizeof(ids->du), 0) != sizeof(ids->du)) {
-              delete ids->du;
+            du->response = std::string(response, responseLen);
+            if (send(du->rsock, &du, sizeof(du), 0) != sizeof(du)) {
+              /* at this point we have the only remaining pointer on this
+                 DOHUnit object since we did set ids->du to nullptr earlier */
+              delete du;
             }
 #endif /* HAVE_DNS_OVER_HTTPS */
-            ids->du = nullptr;
+            du = nullptr;
           }
           else {
             ComboAddress empty;
@@ -610,7 +633,7 @@ try {
 
         double udiff = ids->sentTime.udiff();
         vinfolog("Got answer from %s, relayed to %s%s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(),
-                 ids->du ? " (https)": "", udiff);
+                 isDoH ? " (https)": "", udiff);
 
         struct timespec ts;
         gettime(&ts);
@@ -1017,7 +1040,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
   g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
 
   if(g_qcount.enabled) {
-    string qname = (*dq.qname).toString(".");
+    string qname = (*dq.qname).toLogString();
     bool countQuery{true};
     if(g_qcount.filter) {
       std::lock_guard<std::mutex> lock(g_luamutex);
@@ -1074,7 +1097,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
           return true;
         }
         else {
-          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         }
         break;
       case DNSAction::Action::NoRecurse:
@@ -1106,14 +1129,14 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
         /* do nothing */
         break;
       case DNSAction::Action::Nxdomain:
-        vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+        vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         updateBlockStats();
 
         dq.dh->rcode = RCode::NXDomain;
         dq.dh->qr=true;
         return true;
       case DNSAction::Action::Refused:
-        vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+        vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         updateBlockStats();
 
         dq.dh->rcode = RCode::Refused;
@@ -1123,13 +1146,13 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
         if(!dq.tcp) {
           updateBlockStats();
       
-          vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+          vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
           dq.dh->tc = true;
           dq.dh->qr = true;
           return true;
         }
         else {
-          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+          vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         }
         break;
       case DNSAction::Action::NoRecurse:
@@ -1139,7 +1162,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
         return true;
       default:
         updateBlockStats();
-        vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toString());
+        vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
         return false;
       }
     }
@@ -1223,10 +1246,10 @@ ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss
   else {
     struct msghdr msgh;
     struct iovec iov;
-    char cbuf[256];
+    cmsgbuf_aligned cbuf;
     ComboAddress remote(ss->remote);
-    fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), const_cast<char*>(request), requestLen, &remote);
-    addCMsgSrcAddr(&msgh, cbuf, &ss->sourceAddr, ss->sourceItf);
+    fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), const_cast<char*>(request), requestLen, &remote);
+    addCMsgSrcAddr(&msgh, &cbuf, &ss->sourceAddr, ss->sourceItf);
     result = sendmsg(sd, &msgh, 0);
   }
 
@@ -1320,7 +1343,7 @@ bool checkQueryHeaders(const struct dnsheader* dh)
 }
 
 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
-static void queueResponse(const ClientState& cs, const char* response, uint16_t responseLen, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, char* cbuf)
+static void queueResponse(const ClientState& cs, const char* response, uint16_t responseLen, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, cmsgbuf_aligned* cbuf)
 {
   outMsg.msg_len = 0;
   fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, const_cast<char*>(response), responseLen, const_cast<ComboAddress*>(&remote));
@@ -1471,7 +1494,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
     if(!selectedBackend) {
       ++g_stats.noPolicy;
 
-      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toLogString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
       if (g_servFailOnNoPolicy) {
         restoreFlags(dq.dh, dq.origFlags);
 
@@ -1501,7 +1524,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
   return ProcessQueryResult::Drop;
 }
 
-static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, cmsgbuf_aligned* respCBuf)
 {
   assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
   uint16_t queryId = 0;
@@ -1563,19 +1586,33 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
     IDState* ids = &ss->idStates[idOffset];
     ids->age = 0;
-    ids->du = nullptr;
+    DOHUnit* du = nullptr;
+
+    /* that means that the state was in use, possibly with an allocated
+       DOHUnit that we will need to handle, but we can't touch it before
+       confirming that we now own this state */
+    if (ids->isInUse()) {
+      du = ids->du;
+    }
 
-    int oldFD = ids->origFD.exchange(cs.udpFD);
-    if(oldFD < 0) {
-      // if we are reusing, no change in outstanding
+    /* we atomically replace the value, we now own this state */
+    if (!ids->markAsUsed()) {
+      /* the state was not in use.
+         we reset 'du' because it might have still been in use when we read it. */
+      du = nullptr;
       ++ss->outstanding;
     }
     else {
+      /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
+         to handle it because it's about to be overwritten. */
+      ids->du = nullptr;
       ++ss->reuseds;
       ++g_stats.downstreamTimeouts;
+      handleDOHTimeout(du);
     }
 
     ids->cs = &cs;
+    ids->origFD = cs.udpFD;
     ids->origID = dh->id;
     setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
 
@@ -1604,7 +1641,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ++g_stats.downstreamSendErrors;
     }
 
-    vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
+    vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
   }
   catch(const std::exception& e){
     vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
@@ -1617,11 +1654,11 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
   struct MMReceiver
   {
     char packet[4096];
-    /* used by HarvestDestinationAddress */
-    char cbuf[256];
     ComboAddress remote;
     ComboAddress dest;
     struct iovec iov;
+    /* used by HarvestDestinationAddress */
+    cmsgbuf_aligned cbuf;
   };
   const size_t vectSize = g_udpVectorSize;
   /* the actual buffer is larger because:
@@ -1638,7 +1675,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
   /* initialize the structures needed to receive our messages */
   for (size_t idx = 0; idx < vectSize; idx++) {
     recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family;
-    fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, recvData[idx].cbuf, sizeof(recvData[idx].cbuf), recvData[idx].packet, s_udpIncomingBufferSize, &recvData[idx].remote);
+    fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), recvData[idx].packet, s_udpIncomingBufferSize, &recvData[idx].remote);
   }
 
   /* go now */
@@ -1673,7 +1710,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
         continue;
       }
 
-      processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, static_cast<uint16_t>(got), sizeof(recvData[msgIdx].packet), outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, recvData[msgIdx].cbuf);
+      processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, static_cast<uint16_t>(got), sizeof(recvData[msgIdx].packet), outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, &recvData[msgIdx].cbuf);
 
     }
 
@@ -1717,12 +1754,12 @@ try
     struct msghdr msgh;
     struct iovec iov;
     /* used by HarvestDestinationAddress */
-    char cbuf[256];
+    cmsgbuf_aligned cbuf;
 
     ComboAddress remote;
     ComboAddress dest;
     remote.sin4.sin_family = cs->local.sin4.sin_family;
-    fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), packet, sizeof(packet), &remote);
+    fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), packet, sizeof(packet), &remote);
 
     for(;;) {
       ssize_t got = recvmsg(cs->udpFD, &msgh, 0);
@@ -2053,29 +2090,28 @@ static void healthChecksThread()
       dss->prev.reuseds.store(dss->reuseds.load());
       
       for(IDState& ids  : dss->idStates) { // timeouts
-        int origFD = ids.origFD;
-        if(origFD >=0 && ids.age++ > g_udpTimeout) {
-          /* We set origFD to -1 as soon as possible
+        int64_t usageIndicator = ids.usageIndicator;
+        if(IDState::isInUse(usageIndicator) && ids.age++ > g_udpTimeout) {
+          /* We mark the state as unused as soon as possible
              to limit the risk of racing with the
              responder thread.
-             The UDP client thread only checks origFD to
-             know whether outstanding has to be incremented,
-             so the sooner the better any way since we _will_
-             decrement it.
           */
-          if (ids.origFD.exchange(-1) != origFD) {
+          auto oldDU = ids.du;
+
+          if (!ids.tryMarkUnused(usageIndicator)) {
             /* this state has been altered in the meantime,
                don't go anywhere near it */
             continue;
           }
           ids.du = nullptr;
+          handleDOHTimeout(oldDU);
           ids.age = 0;
           dss->reuseds++;
           --dss->outstanding;
           ++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
           vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s",
                    dss->remote.toStringWithPort(), dss->name,
-                   ids.qname.toString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
+                   ids.qname.toLogString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
 
           struct timespec ts;
           gettime(&ts);
@@ -2792,3 +2828,8 @@ catch(PDNSException &ae)
   errlog("Fatal pdns error: %s", ae.reason);
   _exit(EXIT_FAILURE);
 }
+
+uint64_t getLatencyCount(const std::string&)
+{
+    return g_stats.responses + g_stats.selfAnswered + g_stats.cacheHits;
+}