]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Wrap the DOHUnit object in a unique_ptr whenever possible
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 8 Dec 2021 11:31:00 +0000 (12:31 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 8 Dec 2021 11:31:00 +0000 (12:31 +0100)
pdns/dnsdist-idstate.hh
pdns/dnsdist.cc
pdns/dnsdistdist/doh.cc
pdns/doh.hh

index 043a56911f729099346178b751efc4bbf694fc70..a987cb45ae40b6f6c71df2a8b209fb6c0908549f 100644 (file)
@@ -241,7 +241,7 @@ struct IDState
   std::unique_ptr<QTag> qTag{nullptr}; // 8
   boost::optional<uint32_t> tempFailureTTL; // 8
   const ClientState* cs{nullptr}; // 8
-  DOHUnit* du{nullptr}; // 8
+  DOHUnit* du{nullptr}; // 8 (not a unique_ptr because we currently need to be able to peek at it without knowing taking ownership until later)
   std::atomic<int64_t> usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty   // 8
   std::atomic<uint32_t> generation{0}; // increased every time a state is used, to be able to detect an ABA issue    // 4
   uint32_t cacheKey{0}; // 4
index 9ce83c6681b9bbf200106d00262cc86ac491c4ed..964eb4b43d48045157aead570a14483fbc9ee252 100644 (file)
@@ -630,7 +630,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 
         /* read the potential DOHUnit state as soon as possible, but don't use it
            until we have confirmed that we own this state by updating usageIndicator */
-        auto du = ids->du;
+        auto du = std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>(ids->du, DOHUnit::release);
         /* setting age to 0 to prevent the maintainer thread from
            cleaning this IDS while we process the response.
         */
@@ -656,7 +656,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           --dss->outstanding;  // you'd think an attacker could game this, but we're using connected socket
         } else {
           /* someone updated the state in the meantime, we can't touch the existing pointer */
-          du = nullptr;
+          du.release();
           /* since the state has been updated, we can't safely access it so let's just drop
              this response */
           continue;
@@ -669,8 +669,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
         if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
           // DoH query, we cannot touch du after that
-          du->handleUDPResponse(std::move(response), std::move(*ids));
-          du = nullptr;
+          handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids));
 #endif
           continue;
         }
@@ -1525,20 +1524,20 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
     IDState* ids = &ss->idStates[idOffset];
     ids->age = 0;
-    DOHUnit* du = nullptr;
+    std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du(nullptr, DOHUnit::release);
 
     /* that means that the state was in use, possibly with an allocated
        DOHUnit that we will need to handle, but we can't touch it before
        confirming that we now own this state */
     if (ids->isInUse()) {
-      du = ids->du;
+      du = std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>(ids->du, DOHUnit::release);
     }
 
     /* we atomically replace the value, we now own this state */
     if (!ids->markAsUsed()) {
       /* the state was not in use.
          we reset 'du' because it might have still been in use when we read it. */
-      du = nullptr;
+      du.release();
       ++ss->outstanding;
     }
     else {
@@ -1547,8 +1546,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ids->du = nullptr;
       ++ss->reuseds;
       ++g_stats.downstreamTimeouts;
-      handleDOHTimeout(du);
-      du = nullptr;
+      handleDOHTimeout(std::move(du));
     }
 
     ids->cs = &cs;
@@ -1888,7 +1886,7 @@ static void healthChecksThread()
             continue;
           }
           ids.du = nullptr;
-          handleDOHTimeout(oldDU);
+          handleDOHTimeout(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>(oldDU, DOHUnit::release));
           oldDU = nullptr;
           ids.age = 0;
           dss->reuseds++;
index e7684bd5ee81b9f874498f8fa4a915e9d8c0a7ec..af1291d5a6d011829d45fb83c5fd58f337969a94 100644 (file)
@@ -225,11 +225,13 @@ struct DOHServerConfig
 
 /* This internal function sends back the object to the main thread to send a reply.
    The caller should NOT release or touch the unit after calling this function */
-static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description)
+static void sendDoHUnitToTheMainThread(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du, const char* description)
 {
-  static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
-  ssize_t sent = write(du->rsock, &du, sizeof(du));
-  if (sent != sizeof(du)) {
+  auto ptr = du.release();
+  static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+
+  ssize_t sent = write(ptr->rsock, &ptr, sizeof(ptr));
+  if (sent != sizeof(ptr)) {
     if (errno == EAGAIN || errno == EWOULDBLOCK) {
       ++g_stats.dohResponsePipeFull;
       vinfolog("Unable to pass a %s to the DoH worker thread because the pipe is full", description);
@@ -238,14 +240,14 @@ static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description)
       vinfolog("Unable to pass a %s to the DoH worker thread because we couldn't write to the pipe: %s", description, stringerror());
     }
 
-    du->release();
+    ptr->release();
   }
 }
 
 /* This function is called from other threads than the main DoH one,
    instructing it to send a 502 error to the client.
    It takes ownership of the unit. */
-void handleDOHTimeout(DOHUnit* oldDU)
+void handleDOHTimeout(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& oldDU)
 {
   if (oldDU == nullptr) {
     return;
@@ -254,7 +256,7 @@ void handleDOHTimeout(DOHUnit* oldDU)
 /* we are about to erase an existing DU */
   oldDU->status_code = 502;
 
-  sendDoHUnitToTheMainThread(oldDU, "DoH timeout");
+  sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout");
 }
 
 struct DOHConnection
@@ -411,17 +413,10 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo
 class DoHTCPCrossQuerySender : public TCPQuerySender
 {
 public:
-  DoHTCPCrossQuerySender(DOHUnit* du_): du(du_)
+  DoHTCPCrossQuerySender(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du_): du(std::move(du_))
   {
   }
 
-  ~DoHTCPCrossQuerySender()
-  {
-    if (du != nullptr) {
-      du->release();
-    }
-  }
-
   bool active() const override
   {
     return true;
@@ -455,8 +450,7 @@ public:
     memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
     if (!processResponse(du->response, localRespRuleActions, dr, false, false)) {
-      du->release();
-      du = nullptr;
+      du.reset();
       return;
     }
 
@@ -470,8 +464,7 @@ public:
       ++du->ids.cs->responses;
     }
 
-    sendDoHUnitToTheMainThread(du, "cross-protocol response");
-    du = nullptr;
+    sendDoHUnitToTheMainThread(std::move(du), "cross-protocol response");
   }
 
   void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
@@ -491,18 +484,17 @@ public:
 
     du->ids = std::move(query);
     du->status_code = 502;
-    sendDoHUnitToTheMainThread(du, "cross-protocol error response");
-    du = nullptr;
+    sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response");
   }
 
 private:
-  DOHUnit* du{nullptr};
+  std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du;
 };
 
 class DoHCrossProtocolQuery : public CrossProtocolQuery
 {
 public:
-  DoHCrossProtocolQuery(DOHUnit* du_): du(du_)
+  DoHCrossProtocolQuery(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du_): du(std::move(du_))
   {
     query = InternalQuery(std::move(du->query), std::move(du->ids));
     /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
@@ -515,29 +507,26 @@ public:
     proxyProtocolPayloadSize = du->proxyProtocolPayloadSize;
   }
 
-  ~DoHCrossProtocolQuery()
+  void handleInternalError()
   {
-    if (du != nullptr) {
-      du->release();
-    }
+    du->status_code = 502;
+    sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
   }
 
   std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
   {
-    auto sender = std::make_shared<DoHTCPCrossQuerySender>(du);
-    du = nullptr;
+    auto sender = std::make_shared<DoHTCPCrossQuerySender>(std::move(du));
     return sender;
   }
 
 private:
-  DOHUnit* du{nullptr};
+  std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du;
 };
 
 /*
-   This function takes ownership of the DOHUnit.
    We are not in the main DoH thread but in the DoH 'client' thread.
 */
-static void processDOHQuery(DOHUnit* du)
+static void processDOHQuery(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du)
 {
   uint16_t queryId = 0;
   ComboAddress remote;
@@ -548,7 +537,7 @@ static void processDOHQuery(DOHUnit* du)
       // but we should be fine as long as we don't touch du->req
       // outside of the main DoH thread
       du->status_code = 500;
-      sendDoHUnitToTheMainThread(du, "DoH killed in flight");
+      sendDoHUnitToTheMainThread(std::move(du), "DoH killed in flight");
       return;
     }
     remote = du->ids.origRemote;
@@ -559,7 +548,7 @@ static void processDOHQuery(DOHUnit* du)
     if (du->query.size() < sizeof(dnsheader)) {
       ++g_stats.nonCompliantQueries;
       du->status_code = 400;
-      sendDoHUnitToTheMainThread(du, "DoH non-compliant query");
+      sendDoHUnitToTheMainThread(std::move(du), "DoH non-compliant query");
       return;
     }
 
@@ -578,7 +567,7 @@ static void processDOHQuery(DOHUnit* du)
 
       if (!checkQueryHeaders(dh)) {
         du->status_code = 400;
-        sendDoHUnitToTheMainThread(du, "DoH invalid headers");
+        sendDoHUnitToTheMainThread(std::move(du), "DoH invalid headers");
         return;
       }
 
@@ -587,7 +576,7 @@ static void processDOHQuery(DOHUnit* du)
         dh->qr = true;
         du->response = std::move(du->query);
 
-        sendDoHUnitToTheMainThread(du, "DoH empty query");
+        sendDoHUnitToTheMainThread(std::move(du), "DoH empty query");
         return;
       }
 
@@ -599,14 +588,15 @@ static void processDOHQuery(DOHUnit* du)
     DNSName qname(reinterpret_cast<const char*>(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
     DNSQuestion dq(&qname, qtype, qclass, &du->ids.origDest, &du->ids.origRemote, du->query, dnsdist::Protocol::DoH, &queryRealTime);
     dq.ednsAdded = du->ids.ednsAdded;
-    dq.du = du;
+    /* store the raw pointer */
+    dq.du = du.get();
     dq.sni = std::move(du->sni);
 
     auto result = processQuery(dq, cs, holders, du->downstream);
 
     if (result == ProcessQueryResult::Drop) {
       du->status_code = 403;
-      sendDoHUnitToTheMainThread(du, "DoH dropped query");
+      sendDoHUnitToTheMainThread(std::move(du), "DoH dropped query");
       return;
     }
 
@@ -614,19 +604,19 @@ static void processDOHQuery(DOHUnit* du)
       if (du->response.empty()) {
         du->response = std::move(du->query);
       }
-      sendDoHUnitToTheMainThread(du, "DoH self-answered response");
+      sendDoHUnitToTheMainThread(std::move(du), "DoH self-answered response");
       return;
     }
 
     if (result != ProcessQueryResult::PassToBackend) {
       du->status_code = 500;
-      sendDoHUnitToTheMainThread(du, "DoH no backend available");
+      sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available");
       return;
     }
 
     if (du->downstream == nullptr) {
       du->status_code = 502;
-      sendDoHUnitToTheMainThread(du, "DoH no backend available");
+      sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available");
       return;
     }
 
@@ -642,20 +632,18 @@ static void processDOHQuery(DOHUnit* du)
       du->ids.cs = &cs;
       setIDStateFromDNSQuestion(du->ids, dq, std::move(qname));
 
-      /* we increment the ref counter because we store a copy in the DoHCrossProtocolQuery object */
-      du->get();
+      du->tcp = true;
+      std::shared_ptr<DownstreamState>& downstream = du->downstream;
+
       /* this moves du->ids, careful! */
-      auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du));
       cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
-      du->tcp = true;
-      if (du->downstream->passCrossProtocolQuery(std::move(cpq))) {
-        du->release();
+
+      if (downstream->passCrossProtocolQuery(std::move(cpq))) {
         return;
       }
       else {
-        /* only release du once here, since it also belongs to the DoHCrossProtocolQuery object */
-        du->status_code = 502;
-        sendDoHUnitToTheMainThread(du, "DoH internal error");
+        cpq->handleInternalError();
         return;
       }
     }
@@ -686,14 +674,15 @@ static void processDOHQuery(DOHUnit* du)
          to handle it because it's about to be overwritten. */
       ++du->downstream->reuseds;
       ++g_stats.downstreamTimeouts;
-      handleDOHTimeout(oldDU);
+      handleDOHTimeout(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>(oldDU, DOHUnit::release));
     }
 
     ids->origFD = 0;
     /* increase the ref count since we are about to store the pointer */
     du->get();
     duRefCountIncremented = true;
-    ids->du = du;
+    /* store the raw pointer */
+    ids->du = du.get();
 
     ids->cs = &cs;
     ids->origID = htons(queryId);
@@ -741,7 +730,7 @@ static void processDOHQuery(DOHUnit* du)
         ++du->downstream->sendErrors;
         ++g_stats.downstreamSendErrors;
         du->status_code = 502;
-        sendDoHUnitToTheMainThread(du, "DoH internal error");
+        sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
         return;
       }
     }
@@ -757,11 +746,10 @@ static void processDOHQuery(DOHUnit* du)
   catch (const std::exception& e) {
     vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
     du->status_code = 500;
-    sendDoHUnitToTheMainThread(du, "DoH internal error");
+    sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
     return;
   }
 
-  du->release();
   return;
 }
 
@@ -1239,16 +1227,17 @@ static void dnsdistclient(int qsock)
 
   for(;;) {
     try {
-      DOHUnit* du = nullptr;
-      ssize_t got = read(qsock, &du, sizeof(du));
+      DOHUnit* ptr = nullptr;
+      ssize_t got = read(qsock, &ptr, sizeof(ptr));
       if (got < 0) {
         warnlog("Error receiving internal DoH query: %s", strerror(errno));
         continue;
       }
-      else if (static_cast<size_t>(got) < sizeof(du)) {
+      else if (static_cast<size_t>(got) < sizeof(ptr)) {
         continue;
       }
 
+      std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du(ptr, DOHUnit::release);
       /* we are not in the main DoH thread anymore, so there is a real risk of
          a race condition where h2o kills the query while we are processing it,
          so we can't touch the content of du->req until we are back into the
@@ -1256,7 +1245,6 @@ static void dnsdistclient(int qsock)
       if (!du->req) {
         // it got killed in flight already
         du->self = nullptr;
-        du->release();
         continue;
       }
 
@@ -1275,14 +1263,12 @@ static void dnsdistclient(int qsock)
         // we leave existing EDNS in place
       }
 
-      /* we transfer the ownership of du to this function */
-      processDOHQuery(du);
-      du = nullptr;
+      processDOHQuery(std::move(du));
     }
-    catch(const std::exception& e) {
+    catch (const std::exception& e) {
       errlog("Error while processing query received over DoH: %s", e.what());
     }
-    catch(...) {
+    catch (...) {
       errlog("Unspecified error while processing query received over DoH");
     }
   }
@@ -1303,9 +1289,9 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
      anyway, otherwise queries and responses are piling up in our pipes, consuming
      memory and likely coming up too late after the client has gone away */
   while (true) {
-    DOHUnit *du = nullptr;
+    DOHUnit *ptr = nullptr;
     DOHServerConfig* dsc = reinterpret_cast<DOHServerConfig*>(listener->data);
-    ssize_t got = read(dsc->dohresponsepair[1], &du, sizeof(du));
+    ssize_t got = read(dsc->dohresponsepair[1], &ptr, sizeof(ptr));
 
     if (got < 0) {
       if (errno != EWOULDBLOCK && errno != EAGAIN) {
@@ -1313,14 +1299,14 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
       }
       return;
     }
-    else if (static_cast<size_t>(got) != sizeof(du)) {
-      errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(du));
+    else if (static_cast<size_t>(got) != sizeof(ptr)) {
+      errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(ptr));
       return;
     }
 
+    std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du(ptr, DOHUnit::release);
     if (!du->req) { // it got killed in flight
       du->self = nullptr;
-      du->release();
       continue;
     }
 
@@ -1329,15 +1315,14 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
       dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(du->query.data() + du->proxyProtocolPayloadSize);
       queryDH->id = du->ids.origID;
 
-      auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
       du->tcp = true;
       du->truncated = false;
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du));
 
       if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
         continue;
       }
       else {
-        du->release();
         vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP");
         continue;
       }
@@ -1351,8 +1336,6 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
     }
 
     handleResponse(*dsc->df, du->req, du->status_code, du->response, dsc->df->d_customResponseHeaders, du->contentType, true);
-
-    du->release();
   }
 }
 
@@ -1649,38 +1632,37 @@ void dohThread(ClientState* cs)
   }
 }
 
-void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state)
+void handleUDPResponseForDoH(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du, PacketBuffer&& udpResponse, IDState&& state)
 {
-  response = std::move(udpResponse);
-  ids = std::move(state);
+  du->response = std::move(udpResponse);
+  du->ids = std::move(state);
 
-  const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response.data());
+  const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
   if (!dh->tc) {
     thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
-    DNSResponse dr = makeDNSResponseFromIDState(ids, response);
+    DNSResponse dr = makeDNSResponseFromIDState(du->ids, du->response);
     dnsheader cleartextDH;
     memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
-    if (!processResponse(response, localRespRuleActions, dr, false, true)) {
-      release();
+    if (!processResponse(du->response, localRespRuleActions, dr, false, true)) {
       return;
     }
 
-    double udiff = ids.sentTime.udiff();
-    vinfolog("Got answer from %s, relayed to %s (https), took %f usec", downstream->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff);
+    double udiff = du->ids.sentTime.udiff();
+    vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff);
 
-    handleResponseSent(ids, udiff, *dr.remote, downstream->remote, response.size(), cleartextDH, downstream->getProtocol());
+    handleResponseSent(du->ids, udiff, *dr.remote, du->downstream->remote, du->response.size(), cleartextDH, du->downstream->getProtocol());
 
     ++g_stats.responses;
-    if (ids.cs) {
-      ++ids.cs->responses;
+    if (du->ids.cs) {
+      ++du->ids.cs->responses;
     }
   }
   else {
-    truncated = true;
+    du->truncated = true;
   }
 
-  sendDoHUnitToTheMainThread(this, "DoH response");
+  sendDoHUnitToTheMainThread(std::move(du), "DoH response");
 }
 
 #else /* HAVE_DNS_OVER_HTTPS */
index 16210d94933f09cf52cddf4b26e2b7be2b49d72f..65fc6d604a9a6b5b39dcd09579946d63eee5f251 100644 (file)
@@ -173,6 +173,9 @@ struct DOHFrontend
 #ifndef HAVE_DNS_OVER_HTTPS
 struct DOHUnit
 {
+  static void release(DOHUnit* ptr)
+  {
+  }
 };
 
 #else /* HAVE_DNS_OVER_HTTPS */
@@ -208,7 +211,12 @@ struct DOHUnit
     }
   }
 
-  void handleUDPResponse(PacketBuffer&& response, IDState&& state);
+  static void release(DOHUnit* ptr)
+  {
+    if (ptr) {
+      ptr->release();
+    }
+  }
 
   IDState ids;
   std::string sni;
@@ -248,6 +256,8 @@ struct DOHUnit
   void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType="");
 };
 
+void handleUDPResponseForDoH(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&&, PacketBuffer&& response, IDState&& state);
+
 #endif /* HAVE_DNS_OVER_HTTPS  */
 
-void handleDOHTimeout(DOHUnit* oldDU);
+void handleDOHTimeout(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& oldDU);