]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactoring of the DoH unit handling
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 8 Dec 2021 10:15:08 +0000 (11:15 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 8 Dec 2021 10:15:08 +0000 (11:15 +0100)
pdns/dnsdist.cc
pdns/dnsdistdist/doh.cc

index 4a18e7e7ad0b1129e0a48c0e1e6b6c106632b844..9ce83c6681b9bbf200106d00262cc86ac491c4ed 100644 (file)
@@ -668,8 +668,9 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
         /* don't call processResponse for DOH */
         if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
-          // DoH query
+          // DoH query, we cannot touch du after that
           du->handleUDPResponse(std::move(response), std::move(*ids));
+          du = nullptr;
 #endif
           continue;
         }
@@ -1547,6 +1548,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ++ss->reuseds;
       ++g_stats.downstreamTimeouts;
       handleDOHTimeout(du);
+      du = nullptr;
     }
 
     ids->cs = &cs;
@@ -1887,6 +1889,7 @@ static void healthChecksThread()
           }
           ids.du = nullptr;
           handleDOHTimeout(oldDU);
+          oldDU = nullptr;
           ids.age = 0;
           dss->reuseds++;
           --dss->outstanding;
index 9986ab2c98cb34d0af13047ee822c97feecd8d3b..e7684bd5ee81b9f874498f8fa4a915e9d8c0a7ec 100644 (file)
@@ -223,12 +223,10 @@ struct DOHServerConfig
   int dohresponsepair[2]{-1,-1};
 };
 
-
+/* This internal function sends back the object to the main thread to send a reply.
+   The caller should NOT release or touch the unit after calling this function */
 static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description)
 {
-  /* increase the ref counter before sending the pointer */
-  du->get();
-
   static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
   ssize_t sent = write(du->rsock, &du, sizeof(du));
   if (sent != sizeof(du)) {
@@ -245,7 +243,8 @@ static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description)
 }
 
 /* This function is called from other threads than the main DoH one,
-   instructing it to send a 502 error to the client */
+   instructing it to send a 502 error to the client.
+   It takes ownership of the unit. */
 void handleDOHTimeout(DOHUnit* oldDU)
 {
   if (oldDU == nullptr) {
@@ -256,8 +255,6 @@ void handleDOHTimeout(DOHUnit* oldDU)
   oldDU->status_code = 502;
 
   sendDoHUnitToTheMainThread(oldDU, "DoH timeout");
-
-  oldDU->release();
 }
 
 struct DOHConnection
@@ -474,7 +471,6 @@ public:
     }
 
     sendDoHUnitToTheMainThread(du, "cross-protocol response");
-    du->release();
     du = nullptr;
   }
 
@@ -496,7 +492,6 @@ public:
     du->ids = std::move(query);
     du->status_code = 502;
     sendDoHUnitToTheMainThread(du, "cross-protocol error response");
-    du->release();
     du = nullptr;
   }
 
@@ -539,11 +534,10 @@ private:
 };
 
 /*
-   this function calls 'return -1' to drop a query without sending it
-   caller should make sure HTTPS thread hears of that
+   This function takes ownership of the DOHUnit.
    We are not in the main DoH thread but in the DoH 'client' thread.
 */
-static int processDOHQuery(DOHUnit* du)
+static void processDOHQuery(DOHUnit* du)
 {
   uint16_t queryId = 0;
   ComboAddress remote;
@@ -553,7 +547,9 @@ static int processDOHQuery(DOHUnit* du)
       // we got closed meanwhile. XXX small race condition here
       // but we should be fine as long as we don't touch du->req
       // outside of the main DoH thread
-      return -1;
+      du->status_code = 500;
+      sendDoHUnitToTheMainThread(du, "DoH killed in flight");
+      return;
     }
     remote = du->ids.origRemote;
     DOHServerConfig* dsc = du->dsc;
@@ -563,7 +559,8 @@ static int processDOHQuery(DOHUnit* du)
     if (du->query.size() < sizeof(dnsheader)) {
       ++g_stats.nonCompliantQueries;
       du->status_code = 400;
-      return -1;
+      sendDoHUnitToTheMainThread(du, "DoH non-compliant query");
+      return;
     }
 
     ++cs.queries;
@@ -581,7 +578,8 @@ static int processDOHQuery(DOHUnit* du)
 
       if (!checkQueryHeaders(dh)) {
         du->status_code = 400;
-        return -1; // drop
+        sendDoHUnitToTheMainThread(du, "DoH invalid headers");
+        return;
       }
 
       if (dh->qdcount == 0) {
@@ -589,9 +587,8 @@ static int processDOHQuery(DOHUnit* du)
         dh->qr = true;
         du->response = std::move(du->query);
 
-        sendDoHUnitToTheMainThread(du, "DoH self-answered response");
-
-        return 0;
+        sendDoHUnitToTheMainThread(du, "DoH empty query");
+        return;
       }
 
       queryId = ntohs(dh->id);
@@ -609,7 +606,8 @@ static int processDOHQuery(DOHUnit* du)
 
     if (result == ProcessQueryResult::Drop) {
       du->status_code = 403;
-      return -1;
+      sendDoHUnitToTheMainThread(du, "DoH dropped query");
+      return;
     }
 
     if (result == ProcessQueryResult::SendAnswer) {
@@ -617,18 +615,19 @@ static int processDOHQuery(DOHUnit* du)
         du->response = std::move(du->query);
       }
       sendDoHUnitToTheMainThread(du, "DoH self-answered response");
-
-      return 0;
+      return;
     }
 
     if (result != ProcessQueryResult::PassToBackend) {
       du->status_code = 500;
-      return -1;
+      sendDoHUnitToTheMainThread(du, "DoH no backend available");
+      return;
     }
 
     if (du->downstream == nullptr) {
       du->status_code = 502;
-      return -1;
+      sendDoHUnitToTheMainThread(du, "DoH no backend available");
+      return;
     }
 
     if (du->downstream->isTCPOnly()) {
@@ -643,18 +642,21 @@ static int processDOHQuery(DOHUnit* du)
       du->ids.cs = &cs;
       setIDStateFromDNSQuestion(du->ids, dq, std::move(qname));
 
-      /* this moves du->ids, careful! */
+      /* we increment the ref counter because we store a copy in the DoHCrossProtocolQuery object */
       du->get();
+      /* this moves du->ids, careful! */
       auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
       cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
       du->tcp = true;
       if (du->downstream->passCrossProtocolQuery(std::move(cpq))) {
-        return 0;
+        du->release();
+        return;
       }
       else {
-        /* do not release du here, it belongs to the DoHCrossProtocolQuery object */
+        /* only release du once here, since it also belongs to the DoHCrossProtocolQuery object */
         du->status_code = 502;
-        return -1;
+        sendDoHUnitToTheMainThread(du, "DoH internal error");
+        return;
       }
     }
 
@@ -739,7 +741,8 @@ static int processDOHQuery(DOHUnit* du)
         ++du->downstream->sendErrors;
         ++g_stats.downstreamSendErrors;
         du->status_code = 502;
-        return -1;
+        sendDoHUnitToTheMainThread(du, "DoH internal error");
+        return;
       }
     }
     catch (const std::exception& e) {
@@ -751,13 +754,15 @@ static int processDOHQuery(DOHUnit* du)
 
     vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), du->downstream->getName());
   }
-  catch(const std::exception& e) {
+  catch (const std::exception& e) {
     vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
     du->status_code = 500;
-    return -1;
+    sendDoHUnitToTheMainThread(du, "DoH internal error");
+    return;
   }
 
-  return 0;
+  du->release();
+  return;
 }
 
 /* called when a HTTP response is about to be sent, from the main DoH thread */
@@ -873,8 +878,10 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re
         h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0);
       }
     }
-    catch(...) {
-      ptr->release();
+    catch (...) {
+      if (ptr != nullptr) {
+        ptr->release();
+      }
     }
   }
   catch(const std::exception& e) {
@@ -1268,13 +1275,9 @@ static void dnsdistclient(int qsock)
         // we leave existing EDNS in place
       }
 
-      if (processDOHQuery(du) < 0) {
-        du->status_code = 500;
-
-        sendDoHUnitToTheMainThread(du, "DoH internal error");
-        // XXX if we failed to send it to the main thread, now what - will h2o eventually time this out for us
-      }
-      du->release();
+      /* we transfer the ownership of du to this function */
+      processDOHQuery(du);
+      du = nullptr;
     }
     catch(const std::exception& e) {
       errlog("Error while processing query received over DoH: %s", e.what());
@@ -1678,8 +1681,6 @@ void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state)
   }
 
   sendDoHUnitToTheMainThread(this, "DoH response");
-  /* the reference counter has been incremented in sendDoHUnitToTheMainThread */
-  release();
 }
 
 #else /* HAVE_DNS_OVER_HTTPS */