From: Remi Gacogne Date: Wed, 8 Dec 2021 11:31:00 +0000 (+0100) Subject: dnsdist: Wrap the DOHUnit object in a unique_ptr whenever possible X-Git-Tag: auth-4.7.0-alpha1~122^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=871e5e481310f1e51bfb76319d5f49a27d21804d;p=thirdparty%2Fpdns.git dnsdist: Wrap the DOHUnit object in a unique_ptr whenever possible --- diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index 043a56911f..a987cb45ae 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -241,7 +241,7 @@ struct IDState std::unique_ptr qTag{nullptr}; // 8 boost::optional tempFailureTTL; // 8 const ClientState* cs{nullptr}; // 8 - DOHUnit* du{nullptr}; // 8 + DOHUnit* du{nullptr}; // 8 (not a unique_ptr because we currently need to be able to peek at it without knowing taking ownership until later) std::atomic usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8 std::atomic generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4 uint32_t cacheKey{0}; // 4 diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 9ce83c6681..964eb4b43d 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -630,7 +630,7 @@ void responderThread(std::shared_ptr dss) /* read the potential DOHUnit state as soon as possible, but don't use it until we have confirmed that we own this state by updating usageIndicator */ - auto du = ids->du; + auto du = std::unique_ptr(ids->du, DOHUnit::release); /* setting age to 0 to prevent the maintainer thread from cleaning this IDS while we process the response. */ @@ -656,7 +656,7 @@ void responderThread(std::shared_ptr dss) --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket } else { /* someone updated the state in the meantime, we can't touch the existing pointer */ - du = nullptr; + du.release(); /* since the state has been updated, we can't safely access it so let's just drop this response */ continue; @@ -669,8 +669,7 @@ void responderThread(std::shared_ptr dss) if (du) { #ifdef HAVE_DNS_OVER_HTTPS // DoH query, we cannot touch du after that - du->handleUDPResponse(std::move(response), std::move(*ids)); - du = nullptr; + handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids)); #endif continue; } @@ -1525,20 +1524,20 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct unsigned int idOffset = (ss->idOffset++) % ss->idStates.size(); IDState* ids = &ss->idStates[idOffset]; ids->age = 0; - DOHUnit* du = nullptr; + std::unique_ptr du(nullptr, DOHUnit::release); /* that means that the state was in use, possibly with an allocated DOHUnit that we will need to handle, but we can't touch it before confirming that we now own this state */ if (ids->isInUse()) { - du = ids->du; + du = std::unique_ptr(ids->du, DOHUnit::release); } /* we atomically replace the value, we now own this state */ if (!ids->markAsUsed()) { /* the state was not in use. we reset 'du' because it might have still been in use when we read it. */ - du = nullptr; + du.release(); ++ss->outstanding; } else { @@ -1547,8 +1546,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids->du = nullptr; ++ss->reuseds; ++g_stats.downstreamTimeouts; - handleDOHTimeout(du); - du = nullptr; + handleDOHTimeout(std::move(du)); } ids->cs = &cs; @@ -1888,7 +1886,7 @@ static void healthChecksThread() continue; } ids.du = nullptr; - handleDOHTimeout(oldDU); + handleDOHTimeout(std::unique_ptr(oldDU, DOHUnit::release)); oldDU = nullptr; ids.age = 0; dss->reuseds++; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index e7684bd5ee..af1291d5a6 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -225,11 +225,13 @@ struct DOHServerConfig /* This internal function sends back the object to the main thread to send a reply. The caller should NOT release or touch the unit after calling this function */ -static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description) +static void sendDoHUnitToTheMainThread(std::unique_ptr&& du, const char* description) { - static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); - ssize_t sent = write(du->rsock, &du, sizeof(du)); - if (sent != sizeof(du)) { + auto ptr = du.release(); + static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + + ssize_t sent = write(ptr->rsock, &ptr, sizeof(ptr)); + if (sent != sizeof(ptr)) { if (errno == EAGAIN || errno == EWOULDBLOCK) { ++g_stats.dohResponsePipeFull; vinfolog("Unable to pass a %s to the DoH worker thread because the pipe is full", description); @@ -238,14 +240,14 @@ static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description) vinfolog("Unable to pass a %s to the DoH worker thread because we couldn't write to the pipe: %s", description, stringerror()); } - du->release(); + ptr->release(); } } /* This function is called from other threads than the main DoH one, instructing it to send a 502 error to the client. It takes ownership of the unit. */ -void handleDOHTimeout(DOHUnit* oldDU) +void handleDOHTimeout(std::unique_ptr&& oldDU) { if (oldDU == nullptr) { return; @@ -254,7 +256,7 @@ void handleDOHTimeout(DOHUnit* oldDU) /* we are about to erase an existing DU */ oldDU->status_code = 502; - sendDoHUnitToTheMainThread(oldDU, "DoH timeout"); + sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout"); } struct DOHConnection @@ -411,17 +413,10 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo class DoHTCPCrossQuerySender : public TCPQuerySender { public: - DoHTCPCrossQuerySender(DOHUnit* du_): du(du_) + DoHTCPCrossQuerySender(std::unique_ptr&& du_): du(std::move(du_)) { } - ~DoHTCPCrossQuerySender() - { - if (du != nullptr) { - du->release(); - } - } - bool active() const override { return true; @@ -455,8 +450,7 @@ public: memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); if (!processResponse(du->response, localRespRuleActions, dr, false, false)) { - du->release(); - du = nullptr; + du.reset(); return; } @@ -470,8 +464,7 @@ public: ++du->ids.cs->responses; } - sendDoHUnitToTheMainThread(du, "cross-protocol response"); - du = nullptr; + sendDoHUnitToTheMainThread(std::move(du), "cross-protocol response"); } void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override @@ -491,18 +484,17 @@ public: du->ids = std::move(query); du->status_code = 502; - sendDoHUnitToTheMainThread(du, "cross-protocol error response"); - du = nullptr; + sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response"); } private: - DOHUnit* du{nullptr}; + std::unique_ptr du; }; class DoHCrossProtocolQuery : public CrossProtocolQuery { public: - DoHCrossProtocolQuery(DOHUnit* du_): du(du_) + DoHCrossProtocolQuery(std::unique_ptr&& du_): du(std::move(du_)) { query = InternalQuery(std::move(du->query), std::move(du->ids)); /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload, @@ -515,29 +507,26 @@ public: proxyProtocolPayloadSize = du->proxyProtocolPayloadSize; } - ~DoHCrossProtocolQuery() + void handleInternalError() { - if (du != nullptr) { - du->release(); - } + du->status_code = 502; + sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); } std::shared_ptr getTCPQuerySender() override { - auto sender = std::make_shared(du); - du = nullptr; + auto sender = std::make_shared(std::move(du)); return sender; } private: - DOHUnit* du{nullptr}; + std::unique_ptr du; }; /* - This function takes ownership of the DOHUnit. We are not in the main DoH thread but in the DoH 'client' thread. */ -static void processDOHQuery(DOHUnit* du) +static void processDOHQuery(std::unique_ptr&& du) { uint16_t queryId = 0; ComboAddress remote; @@ -548,7 +537,7 @@ static void processDOHQuery(DOHUnit* du) // but we should be fine as long as we don't touch du->req // outside of the main DoH thread du->status_code = 500; - sendDoHUnitToTheMainThread(du, "DoH killed in flight"); + sendDoHUnitToTheMainThread(std::move(du), "DoH killed in flight"); return; } remote = du->ids.origRemote; @@ -559,7 +548,7 @@ static void processDOHQuery(DOHUnit* du) if (du->query.size() < sizeof(dnsheader)) { ++g_stats.nonCompliantQueries; du->status_code = 400; - sendDoHUnitToTheMainThread(du, "DoH non-compliant query"); + sendDoHUnitToTheMainThread(std::move(du), "DoH non-compliant query"); return; } @@ -578,7 +567,7 @@ static void processDOHQuery(DOHUnit* du) if (!checkQueryHeaders(dh)) { du->status_code = 400; - sendDoHUnitToTheMainThread(du, "DoH invalid headers"); + sendDoHUnitToTheMainThread(std::move(du), "DoH invalid headers"); return; } @@ -587,7 +576,7 @@ static void processDOHQuery(DOHUnit* du) dh->qr = true; du->response = std::move(du->query); - sendDoHUnitToTheMainThread(du, "DoH empty query"); + sendDoHUnitToTheMainThread(std::move(du), "DoH empty query"); return; } @@ -599,14 +588,15 @@ static void processDOHQuery(DOHUnit* du) DNSName qname(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); DNSQuestion dq(&qname, qtype, qclass, &du->ids.origDest, &du->ids.origRemote, du->query, dnsdist::Protocol::DoH, &queryRealTime); dq.ednsAdded = du->ids.ednsAdded; - dq.du = du; + /* store the raw pointer */ + dq.du = du.get(); dq.sni = std::move(du->sni); auto result = processQuery(dq, cs, holders, du->downstream); if (result == ProcessQueryResult::Drop) { du->status_code = 403; - sendDoHUnitToTheMainThread(du, "DoH dropped query"); + sendDoHUnitToTheMainThread(std::move(du), "DoH dropped query"); return; } @@ -614,19 +604,19 @@ static void processDOHQuery(DOHUnit* du) if (du->response.empty()) { du->response = std::move(du->query); } - sendDoHUnitToTheMainThread(du, "DoH self-answered response"); + sendDoHUnitToTheMainThread(std::move(du), "DoH self-answered response"); return; } if (result != ProcessQueryResult::PassToBackend) { du->status_code = 500; - sendDoHUnitToTheMainThread(du, "DoH no backend available"); + sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available"); return; } if (du->downstream == nullptr) { du->status_code = 502; - sendDoHUnitToTheMainThread(du, "DoH no backend available"); + sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available"); return; } @@ -642,20 +632,18 @@ static void processDOHQuery(DOHUnit* du) du->ids.cs = &cs; setIDStateFromDNSQuestion(du->ids, dq, std::move(qname)); - /* we increment the ref counter because we store a copy in the DoHCrossProtocolQuery object */ - du->get(); + du->tcp = true; + std::shared_ptr& downstream = du->downstream; + /* this moves du->ids, careful! */ - auto cpq = std::make_unique(du); + auto cpq = std::make_unique(std::move(du)); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - du->tcp = true; - if (du->downstream->passCrossProtocolQuery(std::move(cpq))) { - du->release(); + + if (downstream->passCrossProtocolQuery(std::move(cpq))) { return; } else { - /* only release du once here, since it also belongs to the DoHCrossProtocolQuery object */ - du->status_code = 502; - sendDoHUnitToTheMainThread(du, "DoH internal error"); + cpq->handleInternalError(); return; } } @@ -686,14 +674,15 @@ static void processDOHQuery(DOHUnit* du) to handle it because it's about to be overwritten. */ ++du->downstream->reuseds; ++g_stats.downstreamTimeouts; - handleDOHTimeout(oldDU); + handleDOHTimeout(std::unique_ptr(oldDU, DOHUnit::release)); } ids->origFD = 0; /* increase the ref count since we are about to store the pointer */ du->get(); duRefCountIncremented = true; - ids->du = du; + /* store the raw pointer */ + ids->du = du.get(); ids->cs = &cs; ids->origID = htons(queryId); @@ -741,7 +730,7 @@ static void processDOHQuery(DOHUnit* du) ++du->downstream->sendErrors; ++g_stats.downstreamSendErrors; du->status_code = 502; - sendDoHUnitToTheMainThread(du, "DoH internal error"); + sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); return; } } @@ -757,11 +746,10 @@ static void processDOHQuery(DOHUnit* du) catch (const std::exception& e) { vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); du->status_code = 500; - sendDoHUnitToTheMainThread(du, "DoH internal error"); + sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); return; } - du->release(); return; } @@ -1239,16 +1227,17 @@ static void dnsdistclient(int qsock) for(;;) { try { - DOHUnit* du = nullptr; - ssize_t got = read(qsock, &du, sizeof(du)); + DOHUnit* ptr = nullptr; + ssize_t got = read(qsock, &ptr, sizeof(ptr)); if (got < 0) { warnlog("Error receiving internal DoH query: %s", strerror(errno)); continue; } - else if (static_cast(got) < sizeof(du)) { + else if (static_cast(got) < sizeof(ptr)) { continue; } + std::unique_ptr du(ptr, DOHUnit::release); /* we are not in the main DoH thread anymore, so there is a real risk of a race condition where h2o kills the query while we are processing it, so we can't touch the content of du->req until we are back into the @@ -1256,7 +1245,6 @@ static void dnsdistclient(int qsock) if (!du->req) { // it got killed in flight already du->self = nullptr; - du->release(); continue; } @@ -1275,14 +1263,12 @@ static void dnsdistclient(int qsock) // we leave existing EDNS in place } - /* we transfer the ownership of du to this function */ - processDOHQuery(du); - du = nullptr; + processDOHQuery(std::move(du)); } - catch(const std::exception& e) { + catch (const std::exception& e) { errlog("Error while processing query received over DoH: %s", e.what()); } - catch(...) { + catch (...) { errlog("Unspecified error while processing query received over DoH"); } } @@ -1303,9 +1289,9 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) anyway, otherwise queries and responses are piling up in our pipes, consuming memory and likely coming up too late after the client has gone away */ while (true) { - DOHUnit *du = nullptr; + DOHUnit *ptr = nullptr; DOHServerConfig* dsc = reinterpret_cast(listener->data); - ssize_t got = read(dsc->dohresponsepair[1], &du, sizeof(du)); + ssize_t got = read(dsc->dohresponsepair[1], &ptr, sizeof(ptr)); if (got < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { @@ -1313,14 +1299,14 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) } return; } - else if (static_cast(got) != sizeof(du)) { - errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(du)); + else if (static_cast(got) != sizeof(ptr)) { + errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(ptr)); return; } + std::unique_ptr du(ptr, DOHUnit::release); if (!du->req) { // it got killed in flight du->self = nullptr; - du->release(); continue; } @@ -1329,15 +1315,14 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) dnsheader* queryDH = reinterpret_cast(du->query.data() + du->proxyProtocolPayloadSize); queryDH->id = du->ids.origID; - auto cpq = std::make_unique(du); du->tcp = true; du->truncated = false; + auto cpq = std::make_unique(std::move(du)); if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { continue; } else { - du->release(); vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP"); continue; } @@ -1351,8 +1336,6 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) } handleResponse(*dsc->df, du->req, du->status_code, du->response, dsc->df->d_customResponseHeaders, du->contentType, true); - - du->release(); } } @@ -1649,38 +1632,37 @@ void dohThread(ClientState* cs) } } -void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state) +void handleUDPResponseForDoH(std::unique_ptr&& du, PacketBuffer&& udpResponse, IDState&& state) { - response = std::move(udpResponse); - ids = std::move(state); + du->response = std::move(udpResponse); + du->ids = std::move(state); - const dnsheader* dh = reinterpret_cast(response.data()); + const dnsheader* dh = reinterpret_cast(du->response.data()); if (!dh->tc) { thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); - DNSResponse dr = makeDNSResponseFromIDState(ids, response); + DNSResponse dr = makeDNSResponseFromIDState(du->ids, du->response); dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - if (!processResponse(response, localRespRuleActions, dr, false, true)) { - release(); + if (!processResponse(du->response, localRespRuleActions, dr, false, true)) { return; } - double udiff = ids.sentTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (https), took %f usec", downstream->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); + double udiff = du->ids.sentTime.udiff(); + vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); - handleResponseSent(ids, udiff, *dr.remote, downstream->remote, response.size(), cleartextDH, downstream->getProtocol()); + handleResponseSent(du->ids, udiff, *dr.remote, du->downstream->remote, du->response.size(), cleartextDH, du->downstream->getProtocol()); ++g_stats.responses; - if (ids.cs) { - ++ids.cs->responses; + if (du->ids.cs) { + ++du->ids.cs->responses; } } else { - truncated = true; + du->truncated = true; } - sendDoHUnitToTheMainThread(this, "DoH response"); + sendDoHUnitToTheMainThread(std::move(du), "DoH response"); } #else /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/doh.hh b/pdns/doh.hh index 16210d9493..65fc6d604a 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -173,6 +173,9 @@ struct DOHFrontend #ifndef HAVE_DNS_OVER_HTTPS struct DOHUnit { + static void release(DOHUnit* ptr) + { + } }; #else /* HAVE_DNS_OVER_HTTPS */ @@ -208,7 +211,12 @@ struct DOHUnit } } - void handleUDPResponse(PacketBuffer&& response, IDState&& state); + static void release(DOHUnit* ptr) + { + if (ptr) { + ptr->release(); + } + } IDState ids; std::string sni; @@ -248,6 +256,8 @@ struct DOHUnit void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType=""); }; +void handleUDPResponseForDoH(std::unique_ptr&&, PacketBuffer&& response, IDState&& state); + #endif /* HAVE_DNS_OVER_HTTPS */ -void handleDOHTimeout(DOHUnit* oldDU); +void handleDOHTimeout(std::unique_ptr&& oldDU);