/* This internal function sends back the object to the main thread to send a reply.
The caller should NOT release or touch the unit after calling this function */
-static void sendDoHUnitToTheMainThread(DOHUnit* du, const char* description)
+static void sendDoHUnitToTheMainThread(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du, const char* description)
{
- static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
- ssize_t sent = write(du->rsock, &du, sizeof(du));
- if (sent != sizeof(du)) {
+ auto ptr = du.release();
+ static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+
+ ssize_t sent = write(ptr->rsock, &ptr, sizeof(ptr));
+ if (sent != sizeof(ptr)) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
++g_stats.dohResponsePipeFull;
vinfolog("Unable to pass a %s to the DoH worker thread because the pipe is full", description);
vinfolog("Unable to pass a %s to the DoH worker thread because we couldn't write to the pipe: %s", description, stringerror());
}
- du->release();
+ ptr->release();
}
}
/* This function is called from other threads than the main DoH one,
instructing it to send a 502 error to the client.
It takes ownership of the unit. */
-void handleDOHTimeout(DOHUnit* oldDU)
+void handleDOHTimeout(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& oldDU)
{
if (oldDU == nullptr) {
return;
/* we are about to erase an existing DU */
oldDU->status_code = 502;
- sendDoHUnitToTheMainThread(oldDU, "DoH timeout");
+ sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout");
}
struct DOHConnection
class DoHTCPCrossQuerySender : public TCPQuerySender
{
public:
- DoHTCPCrossQuerySender(DOHUnit* du_): du(du_)
+ DoHTCPCrossQuerySender(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du_): du(std::move(du_))
{
}
- ~DoHTCPCrossQuerySender()
- {
- if (du != nullptr) {
- du->release();
- }
- }
-
bool active() const override
{
return true;
memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
if (!processResponse(du->response, localRespRuleActions, dr, false, false)) {
- du->release();
- du = nullptr;
+ du.reset();
return;
}
++du->ids.cs->responses;
}
- sendDoHUnitToTheMainThread(du, "cross-protocol response");
- du = nullptr;
+ sendDoHUnitToTheMainThread(std::move(du), "cross-protocol response");
}
void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
du->ids = std::move(query);
du->status_code = 502;
- sendDoHUnitToTheMainThread(du, "cross-protocol error response");
- du = nullptr;
+ sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response");
}
private:
- DOHUnit* du{nullptr};
+ std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du;
};
class DoHCrossProtocolQuery : public CrossProtocolQuery
{
public:
- DoHCrossProtocolQuery(DOHUnit* du_): du(du_)
+ DoHCrossProtocolQuery(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du_): du(std::move(du_))
{
query = InternalQuery(std::move(du->query), std::move(du->ids));
/* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
proxyProtocolPayloadSize = du->proxyProtocolPayloadSize;
}
- ~DoHCrossProtocolQuery()
+ void handleInternalError()
{
- if (du != nullptr) {
- du->release();
- }
+ du->status_code = 502;
+ sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
}
std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
{
- auto sender = std::make_shared<DoHTCPCrossQuerySender>(du);
- du = nullptr;
+ auto sender = std::make_shared<DoHTCPCrossQuerySender>(std::move(du));
return sender;
}
private:
- DOHUnit* du{nullptr};
+ std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du;
};
/*
- This function takes ownership of the DOHUnit.
We are not in the main DoH thread but in the DoH 'client' thread.
*/
-static void processDOHQuery(DOHUnit* du)
+static void processDOHQuery(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du)
{
uint16_t queryId = 0;
ComboAddress remote;
// but we should be fine as long as we don't touch du->req
// outside of the main DoH thread
du->status_code = 500;
- sendDoHUnitToTheMainThread(du, "DoH killed in flight");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH killed in flight");
return;
}
remote = du->ids.origRemote;
if (du->query.size() < sizeof(dnsheader)) {
++g_stats.nonCompliantQueries;
du->status_code = 400;
- sendDoHUnitToTheMainThread(du, "DoH non-compliant query");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH non-compliant query");
return;
}
if (!checkQueryHeaders(dh)) {
du->status_code = 400;
- sendDoHUnitToTheMainThread(du, "DoH invalid headers");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH invalid headers");
return;
}
dh->qr = true;
du->response = std::move(du->query);
- sendDoHUnitToTheMainThread(du, "DoH empty query");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH empty query");
return;
}
DNSName qname(reinterpret_cast<const char*>(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength);
DNSQuestion dq(&qname, qtype, qclass, &du->ids.origDest, &du->ids.origRemote, du->query, dnsdist::Protocol::DoH, &queryRealTime);
dq.ednsAdded = du->ids.ednsAdded;
- dq.du = du;
+ /* store the raw pointer */
+ dq.du = du.get();
dq.sni = std::move(du->sni);
auto result = processQuery(dq, cs, holders, du->downstream);
if (result == ProcessQueryResult::Drop) {
du->status_code = 403;
- sendDoHUnitToTheMainThread(du, "DoH dropped query");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH dropped query");
return;
}
if (du->response.empty()) {
du->response = std::move(du->query);
}
- sendDoHUnitToTheMainThread(du, "DoH self-answered response");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH self-answered response");
return;
}
if (result != ProcessQueryResult::PassToBackend) {
du->status_code = 500;
- sendDoHUnitToTheMainThread(du, "DoH no backend available");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available");
return;
}
if (du->downstream == nullptr) {
du->status_code = 502;
- sendDoHUnitToTheMainThread(du, "DoH no backend available");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available");
return;
}
du->ids.cs = &cs;
setIDStateFromDNSQuestion(du->ids, dq, std::move(qname));
- /* we increment the ref counter because we store a copy in the DoHCrossProtocolQuery object */
- du->get();
+ du->tcp = true;
+ std::shared_ptr<DownstreamState>& downstream = du->downstream;
+
/* this moves du->ids, careful! */
- auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
+ auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du));
cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
- du->tcp = true;
- if (du->downstream->passCrossProtocolQuery(std::move(cpq))) {
- du->release();
+
+ if (downstream->passCrossProtocolQuery(std::move(cpq))) {
return;
}
else {
- /* only release du once here, since it also belongs to the DoHCrossProtocolQuery object */
- du->status_code = 502;
- sendDoHUnitToTheMainThread(du, "DoH internal error");
+ cpq->handleInternalError();
return;
}
}
to handle it because it's about to be overwritten. */
++du->downstream->reuseds;
++g_stats.downstreamTimeouts;
- handleDOHTimeout(oldDU);
+ handleDOHTimeout(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>(oldDU, DOHUnit::release));
}
ids->origFD = 0;
/* increase the ref count since we are about to store the pointer */
du->get();
duRefCountIncremented = true;
- ids->du = du;
+ /* store the raw pointer */
+ ids->du = du.get();
ids->cs = &cs;
ids->origID = htons(queryId);
++du->downstream->sendErrors;
++g_stats.downstreamSendErrors;
du->status_code = 502;
- sendDoHUnitToTheMainThread(du, "DoH internal error");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
return;
}
}
catch (const std::exception& e) {
vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
du->status_code = 500;
- sendDoHUnitToTheMainThread(du, "DoH internal error");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
return;
}
- du->release();
return;
}
for(;;) {
try {
- DOHUnit* du = nullptr;
- ssize_t got = read(qsock, &du, sizeof(du));
+ DOHUnit* ptr = nullptr;
+ ssize_t got = read(qsock, &ptr, sizeof(ptr));
if (got < 0) {
warnlog("Error receiving internal DoH query: %s", strerror(errno));
continue;
}
- else if (static_cast<size_t>(got) < sizeof(du)) {
+ else if (static_cast<size_t>(got) < sizeof(ptr)) {
continue;
}
+ std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du(ptr, DOHUnit::release);
/* we are not in the main DoH thread anymore, so there is a real risk of
a race condition where h2o kills the query while we are processing it,
so we can't touch the content of du->req until we are back into the
if (!du->req) {
// it got killed in flight already
du->self = nullptr;
- du->release();
continue;
}
// we leave existing EDNS in place
}
- /* we transfer the ownership of du to this function */
- processDOHQuery(du);
- du = nullptr;
+ processDOHQuery(std::move(du));
}
- catch(const std::exception& e) {
+ catch (const std::exception& e) {
errlog("Error while processing query received over DoH: %s", e.what());
}
- catch(...) {
+ catch (...) {
errlog("Unspecified error while processing query received over DoH");
}
}
anyway, otherwise queries and responses are piling up in our pipes, consuming
memory and likely coming up too late after the client has gone away */
while (true) {
- DOHUnit *du = nullptr;
+ DOHUnit *ptr = nullptr;
DOHServerConfig* dsc = reinterpret_cast<DOHServerConfig*>(listener->data);
- ssize_t got = read(dsc->dohresponsepair[1], &du, sizeof(du));
+ ssize_t got = read(dsc->dohresponsepair[1], &ptr, sizeof(ptr));
if (got < 0) {
if (errno != EWOULDBLOCK && errno != EAGAIN) {
}
return;
}
- else if (static_cast<size_t>(got) != sizeof(du)) {
- errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(du));
+ else if (static_cast<size_t>(got) != sizeof(ptr)) {
+ errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(ptr));
return;
}
+ std::unique_ptr<DOHUnit, void(*)(DOHUnit*)> du(ptr, DOHUnit::release);
if (!du->req) { // it got killed in flight
du->self = nullptr;
- du->release();
continue;
}
dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(du->query.data() + du->proxyProtocolPayloadSize);
queryDH->id = du->ids.origID;
- auto cpq = std::make_unique<DoHCrossProtocolQuery>(du);
du->tcp = true;
du->truncated = false;
+ auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du));
if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
continue;
}
else {
- du->release();
vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP");
continue;
}
}
handleResponse(*dsc->df, du->req, du->status_code, du->response, dsc->df->d_customResponseHeaders, du->contentType, true);
-
- du->release();
}
}
}
}
-void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, IDState&& state)
+void handleUDPResponseForDoH(std::unique_ptr<DOHUnit, void(*)(DOHUnit*)>&& du, PacketBuffer&& udpResponse, IDState&& state)
{
- response = std::move(udpResponse);
- ids = std::move(state);
+ du->response = std::move(udpResponse);
+ du->ids = std::move(state);
- const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response.data());
+ const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
if (!dh->tc) {
thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
- DNSResponse dr = makeDNSResponseFromIDState(ids, response);
+ DNSResponse dr = makeDNSResponseFromIDState(du->ids, du->response);
dnsheader cleartextDH;
memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
- if (!processResponse(response, localRespRuleActions, dr, false, true)) {
- release();
+ if (!processResponse(du->response, localRespRuleActions, dr, false, true)) {
return;
}
- double udiff = ids.sentTime.udiff();
- vinfolog("Got answer from %s, relayed to %s (https), took %f usec", downstream->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff);
+ double udiff = du->ids.sentTime.udiff();
+ vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff);
- handleResponseSent(ids, udiff, *dr.remote, downstream->remote, response.size(), cleartextDH, downstream->getProtocol());
+ handleResponseSent(du->ids, udiff, *dr.remote, du->downstream->remote, du->response.size(), cleartextDH, du->downstream->getProtocol());
++g_stats.responses;
- if (ids.cs) {
- ++ids.cs->responses;
+ if (du->ids.cs) {
+ ++du->ids.cs->responses;
}
}
else {
- truncated = true;
+ du->truncated = true;
}
- sendDoHUnitToTheMainThread(this, "DoH response");
+ sendDoHUnitToTheMainThread(std::move(du), "DoH response");
}
#else /* HAVE_DNS_OVER_HTTPS */