From: Remi Gacogne Date: Wed, 8 Dec 2021 10:15:08 +0000 (+0100) Subject: dnsdist: Refactoring of the DoH unit handling X-Git-Tag: auth-4.7.0-alpha1~122^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2171e7c7ef360c5646646c6504d29d83a74e18c2;p=thirdparty%2Fpdns.git dnsdist: Refactoring of the DoH unit handling --- diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 4a18e7e7ad..9ce83c6681 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -668,8 +668,9 @@ void responderThread(std::shared_ptr dss) /* don't call processResponse for DOH */ if (du) { #ifdef HAVE_DNS_OVER_HTTPS - // DoH query + // DoH query, we cannot touch du after that du->handleUDPResponse(std::move(response), std::move(*ids)); + du = nullptr; #endif continue; } @@ -1547,6 +1548,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ++ss->reuseds; ++g_stats.downstreamTimeouts; handleDOHTimeout(du); + du = nullptr; } ids->cs = &cs; @@ -1887,6 +1889,7 @@ static void healthChecksThread() } ids.du = nullptr; handleDOHTimeout(oldDU); + oldDU = nullptr; ids.age = 0; dss->reuseds++; --dss->outstanding; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 9986ab2c98..e7684bd5ee 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -223,12 +223,10 @@ struct DOHServerConfig int dohresponsepair[2]{-1,-1}; }; - +/* This internal function sends back the object to the main thread to send a reply. + The caller should NOT release or touch the unit after calling this function */ static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description) { - /* increase the ref counter before sending the pointer */ - du->get(); - static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); ssize_t sent = write(du->rsock, &du, sizeof(du)); if (sent != sizeof(du)) { @@ -245,7 +243,8 @@ static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description) } /* This function is called from other threads than the main DoH one, - instructing it to send a 502 error to the client */ + instructing it to send a 502 error to the client. + It takes ownership of the unit. */ void handleDOHTimeout(DOHUnit* oldDU) { if (oldDU == nullptr) { @@ -256,8 +255,6 @@ void handleDOHTimeout(DOHUnit* oldDU) oldDU->status_code = 502; sendDoHUnitToTheMainThread(oldDU, "DoH timeout"); - - oldDU->release(); } struct DOHConnection @@ -474,7 +471,6 @@ public: } sendDoHUnitToTheMainThread(du, "cross-protocol response"); - du->release(); du = nullptr; } @@ -496,7 +492,6 @@ public: du->ids = std::move(query); du->status_code = 502; sendDoHUnitToTheMainThread(du, "cross-protocol error response"); - du->release(); du = nullptr; } @@ -539,11 +534,10 @@ private: }; /* - this function calls 'return -1' to drop a query without sending it - caller should make sure HTTPS thread hears of that + This function takes ownership of the DOHUnit. We are not in the main DoH thread but in the DoH 'client' thread. */ -static int processDOHQuery(DOHUnit* du) +static void processDOHQuery(DOHUnit* du) { uint16_t queryId = 0; ComboAddress remote; @@ -553,7 +547,9 @@ static int processDOHQuery(DOHUnit* du) // we got closed meanwhile. XXX small race condition here // but we should be fine as long as we don't touch du->req // outside of the main DoH thread - return -1; + du->status_code = 500; + sendDoHUnitToTheMainThread(du, "DoH killed in flight"); + return; } remote = du->ids.origRemote; DOHServerConfig* dsc = du->dsc; @@ -563,7 +559,8 @@ static int processDOHQuery(DOHUnit* du) if (du->query.size() < sizeof(dnsheader)) { ++g_stats.nonCompliantQueries; du->status_code = 400; - return -1; + sendDoHUnitToTheMainThread(du, "DoH non-compliant query"); + return; } ++cs.queries; @@ -581,7 +578,8 @@ static int processDOHQuery(DOHUnit* du) if (!checkQueryHeaders(dh)) { du->status_code = 400; - return -1; // drop + sendDoHUnitToTheMainThread(du, "DoH invalid headers"); + return; } if (dh->qdcount == 0) { @@ -589,9 +587,8 @@ static int processDOHQuery(DOHUnit* du) dh->qr = true; du->response = std::move(du->query); - sendDoHUnitToTheMainThread(du, "DoH self-answered response"); - - return 0; + sendDoHUnitToTheMainThread(du, "DoH empty query"); + return; } queryId = ntohs(dh->id); @@ -609,7 +606,8 @@ static int processDOHQuery(DOHUnit* du) if (result == ProcessQueryResult::Drop) { du->status_code = 403; - return -1; + sendDoHUnitToTheMainThread(du, "DoH dropped query"); + return; } if (result == ProcessQueryResult::SendAnswer) { @@ -617,18 +615,19 @@ static int processDOHQuery(DOHUnit* du) du->response = std::move(du->query); } sendDoHUnitToTheMainThread(du, "DoH self-answered response"); - - return 0; + return; } if (result != ProcessQueryResult::PassToBackend) { du->status_code = 500; - return -1; + sendDoHUnitToTheMainThread(du, "DoH no backend available"); + return; } if (du->downstream == nullptr) { du->status_code = 502; - return -1; + sendDoHUnitToTheMainThread(du, "DoH no backend available"); + return; } if (du->downstream->isTCPOnly()) { @@ -643,18 +642,21 @@ static int processDOHQuery(DOHUnit* du) du->ids.cs = &cs; setIDStateFromDNSQuestion(du->ids, dq, std::move(qname)); - /* this moves du->ids, careful! */ + /* we increment the ref counter because we store a copy in the DoHCrossProtocolQuery object */ du->get(); + /* this moves du->ids, careful! */ auto cpq = std::make_unique(du); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); du->tcp = true; if (du->downstream->passCrossProtocolQuery(std::move(cpq))) { - return 0; + du->release(); + return; } else { - /* do not release du here, it belongs to the DoHCrossProtocolQuery object */ + /* only release du once here, since it also belongs to the DoHCrossProtocolQuery object */ du->status_code = 502; - return -1; + sendDoHUnitToTheMainThread(du, "DoH internal error"); + return; } } @@ -739,7 +741,8 @@ static int processDOHQuery(DOHUnit* du) ++du->downstream->sendErrors; ++g_stats.downstreamSendErrors; du->status_code = 502; - return -1; + sendDoHUnitToTheMainThread(du, "DoH internal error"); + return; } } catch (const std::exception& e) { @@ -751,13 +754,15 @@ static int processDOHQuery(DOHUnit* du) vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), du->downstream->getName()); } - catch(const std::exception& e) { + catch (const std::exception& e) { vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); du->status_code = 500; - return -1; + sendDoHUnitToTheMainThread(du, "DoH internal error"); + return; } - return 0; + du->release(); + return; } /* called when a HTTP response is about to be sent, from the main DoH thread */ @@ -873,8 +878,10 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0); } } - catch(...) { - ptr->release(); + catch (...) { + if (ptr != nullptr) { + ptr->release(); + } } } catch(const std::exception& e) { @@ -1268,13 +1275,9 @@ static void dnsdistclient(int qsock) // we leave existing EDNS in place } - if (processDOHQuery(du) < 0) { - du->status_code = 500; - - sendDoHUnitToTheMainThread(du, "DoH internal error"); - // XXX if we failed to send it to the main thread, now what - will h2o eventually time this out for us - } - du->release(); + /* we transfer the ownership of du to this function */ + processDOHQuery(du); + du = nullptr; } catch(const std::exception& e) { errlog("Error while processing query received over DoH: %s", e.what()); @@ -1678,8 +1681,6 @@ void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state) } sendDoHUnitToTheMainThread(this, "DoH response"); - /* the reference counter has been incremented in sendDoHUnitToTheMainThread */ - release(); } #else /* HAVE_DNS_OVER_HTTPS */