]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactor the handling of responses for UDP clients
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 24 Nov 2022 16:21:48 +0000 (17:21 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 14 Dec 2022 15:25:56 +0000 (16:25 +0100)
pdns/dnsdist.cc

index 0e701431b375996a4276a4715012d59b33e43c80..7348fa95a0adba8e33a4aa2e8043b7b62981d0a2 100644 (file)
@@ -636,6 +636,64 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff,
   doLatencyStats(incomingProtocol, udiff);
 }
 
+static void handleResponseForUDPClient(IDState& ids, PacketBuffer& response, uint16_t maxPayloadSize, const std::vector<DNSDistResponseRuleAction>& respRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, const std::shared_ptr<DownstreamState>& ds, bool selfGenerated, std::optional<uint16_t> queryId)
+{
+  DNSResponse dr = makeDNSResponseFromIDState(ids, response);
+
+  if (maxPayloadSize > 0 && response.size() > maxPayloadSize) {
+    vinfolog("Got a response of size %d while the initial UDP payload size was %d, truncating", response.size(), maxPayloadSize);
+    truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.qname->wirelength());
+    dr.getHeader()->tc = true;
+  }
+  else if (dr.getHeader()->tc && g_truncateTC) {
+    truncateTC(response, dr.getMaximumSize(), dr.qname->wirelength());
+  }
+
+  /* when the answer is encrypted in place, we need to get a copy
+     of the original header before encryption to fill the ring buffer */
+  dnsheader cleartextDH;
+  memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
+
+  if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted, true)) {
+    if (queryId) {
+      ds->releaseState(*queryId);
+    }
+    return;
+  }
+
+  ++g_stats.responses;
+  if (ids.cs) {
+    ++ids.cs->responses;
+  }
+
+  bool muted = true;
+  if (ids.cs && !ids.cs->muted) {
+    ComboAddress empty;
+    empty.sin4.sin_family = 0;
+    sendUDPResponse(ids.cs->udpFD, response, dr.delayMsec, ids.hopLocal, ids.hopRemote);
+    muted = false;
+  }
+
+  if (!selfGenerated) {
+    double udiff = ids.sentTime.udiff();
+    if (!muted) {
+      vinfolog("Got answer from %s, relayed to %s (UDP), took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff);
+    }
+    else {
+      vinfolog("Got answer from %s, NOT relayed to %s (UDP) since that frontend is muted, took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff);
+    }
+
+    handleResponseSent(ids, udiff, *dr.remote, ds->d_config.remote, response.size(), cleartextDH, ds->getProtocol());
+  }
+  else {
+    handleResponseSent(ids, 0., *dr.remote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP);
+  }
+
+  if (queryId) {
+    ds->releaseState(*queryId);
+  }
+}
+
 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
 void responderThread(std::shared_ptr<DownstreamState> dss)
 {
@@ -645,10 +703,6 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
   auto localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
   const size_t initialBufferSize = getInitialUDPPacketBufferSize();
   PacketBuffer response(initialBufferSize);
-
-  /* when the answer is encrypted in place, we need to get a copy
-     of the original header before encryption to fill the ring buffer */
-  dnsheader cleartextDH;
   uint16_t queryId = 0;
   std::vector<int> sockets;
   sockets.reserve(dss->sockets.size());
@@ -699,7 +753,6 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
            cleaning this IDS while we process the response.
         */
         ids->age = 0;
-        int origFD = ids->origFD;
 
         unsigned int qnameWireLength = 0;
         if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) {
@@ -744,33 +797,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           continue;
         }
 
-        DNSResponse dr = makeDNSResponseFromIDState(*ids, response);
-        if (dh->tc && g_truncateTC) {
-          truncateTC(response, dr.getMaximumSize(), qnameWireLength);
-        }
-        memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
-
-        if (!processResponse(response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, ids->cs && ids->cs->muted, true)) {
-          dss->releaseState(queryId);
-          continue;
-        }
-
-        ++g_stats.responses;
-        if (ids->cs) {
-          ++ids->cs->responses;
-        }
-
-        if (ids->cs && !ids->cs->muted) {
-          ComboAddress empty;
-          empty.sin4.sin_family = 0;
-          sendUDPResponse(origFD, response, dr.delayMsec, ids->hopLocal, ids->hopRemote);
-        }
-
-        udiff = ids->sentTime.udiff();
-        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->d_config.remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
-
-        handleResponseSent(*ids, udiff, *dr.remote, dss->d_config.remote, static_cast<unsigned int>(got), cleartextDH, dss->getProtocol());
-        dss->releaseState(queryId);
+        handleResponseForUDPClient(*ids, response, 0, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId);
       }
     }
     catch (const std::exception& e){
@@ -1412,7 +1439,7 @@ public:
 
   void handleResponse(const struct timeval& now, TCPResponse&& response) override
   {
-    if (!d_ds) {
+    if (!d_ds && !response.d_selfGenerated) {
       throw std::runtime_error("Passing a cross-protocol answer originated from UDP without a valid downstream");
     }
 
@@ -1420,35 +1447,13 @@ public:
 
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
-    DNSResponse dr = makeDNSResponseFromIDState(ids, response.d_buffer);
-    if (response.d_buffer.size() > d_payloadSize) {
-      vinfolog("Got a response of size %d over TCP, while the initial UDP payload size was %d, truncating", response.d_buffer.size(), d_payloadSize);
-      truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.qname->wirelength());
-      dr.getHeader()->tc = true;
-    }
-
-    dnsheader cleartextDH;
-    memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
-
-    if (!processResponse(response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false, true)) {
-      return;
-    }
-
-    ++g_stats.responses;
-    if (ids.cs) {
-      ++ids.cs->responses;
-    }
-
-    if (ids.cs && !ids.cs->muted) {
-      ComboAddress empty;
-      empty.sin4.sin_family = 0;
-      sendUDPResponse(ids.origFD, response.d_buffer, dr.delayMsec, ids.hopLocal, ids.hopRemote);
-    }
 
     double udiff = ids.sentTime.udiff();
-    vinfolog("Got answer from %s, relayed to %s (UDP), took %f usec", d_ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff);
+    if (d_ds && !response.d_selfGenerated) {
+      vinfolog("Got answer from %s, relayed to %s (UDP), took %f usec", d_ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff);
+    }
 
-    handleResponseSent(ids, udiff, *dr.remote, d_ds->d_config.remote, response.d_buffer.size(), cleartextDH, d_ds->getProtocol());
+    handleResponseForUDPClient(ids, response.d_buffer, d_payloadSize, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated, std::nullopt);
   }
 
   void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override