DOHFrontend()
{
}
- DOHFrontend(std::shared_ptr<TLSCtx> tlsCtx) :
- d_tlsContext(std::move(tlsCtx))
- {
- }
virtual ~DOHFrontend()
{
#endif
bool d_sendCacheControlHeaders{true};
bool d_trustForwardedForHeader{false};
+ bool d_earlyACLDrop{true};
/* whether we require tue query path to exactly match one of configured ones,
or accept everything below these paths. */
bool d_exactPathMatching{true};
struct timeval now;
gettimeofday(&now, nullptr);
- sender->notifyIOError(std::move(object->query.d_idstate), now);
+ sender->notifyIOError(now, TCPResponse(std::move(object->query)));
return true;
}
setLuaSideEffect();
auto frontend = std::make_shared<DOHFrontend>();
+ if (getOptionalValue<std::string>(vars, "library", frontend->d_library) == 0) {
+#ifdef HAVE_NGHTTP2
+ frontend->d_library = "nghttp2";
+#else /* HAVE_NGHTTP2 */
+ frontend->d_library = "h2o";
+#endif /* HAVE_NGHTTP2 */
+ }
+ if (frontend->d_library == "h2o") {
#ifdef HAVE_LIBH2OEVLOOP
- frontend = std::make_shared<H2ODOHFrontend>();
- frontend->d_library = "h2o";
+ frontend = std::make_shared<H2ODOHFrontend>();
+ frontend->d_library = "h2o";
#else /* HAVE_LIBH2OEVLOOP */
- errlog("DOH bind %s is configured to use libh2o but the library is not available", addr);
- return;
+ errlog("DOH bind %s is configured to use libh2o but the library is not available", addr);
+ return;
#endif /* HAVE_LIBH2OEVLOOP */
+ }
+ else if (frontend->d_library == "nghttp2") {
+#ifndef HAVE_NGHTTP2
+ errlog("DOH bind %s is configured to use nghttp2 but the library is not available", addr);
+ return;
+#endif /* HAVE_NGHTTP2 */
+ }
+ else {
+ errlog("DOH bind %s is configured to use an unknown library ('%s')", addr, frontend->d_library);
+ return;
+ }
+ bool useTLS = true;
if (certFiles && !certFiles->empty()) {
if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) {
return;
else {
frontend->d_tlsContext.d_addr = ComboAddress(addr, 80);
infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort());
+ useTLS = false;
}
if (urls) {
parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections);
getOptionalValue<int>(vars, "idleTimeout", frontend->d_idleTimeout);
getOptionalValue<std::string>(vars, "serverTokens", frontend->d_serverTokens);
+ getOptionalValue<std::string>(vars, "provider", frontend->d_tlsContext.d_provider);
+ boost::algorithm::to_lower(frontend->d_tlsContext.d_provider);
LuaAssociativeTable<std::string> customResponseHeaders;
if (getOptionalValue<decltype(customResponseHeaders)>(vars, "customResponseHeaders", customResponseHeaders) > 0) {
getOptionalValue<bool>(vars, "sendCacheControlHeaders", frontend->d_sendCacheControlHeaders);
getOptionalValue<bool>(vars, "keepIncomingHeaders", frontend->d_keepIncomingHeaders);
getOptionalValue<bool>(vars, "trustForwardedForHeader", frontend->d_trustForwardedForHeader);
+ getOptionalValue<bool>(vars, "earlyACLDrop", frontend->d_earlyACLDrop);
getOptionalValue<int>(vars, "internalPipeBufferSize", frontend->d_internalPipeBufferSize);
getOptionalValue<bool>(vars, "exactPathMatching", frontend->d_exactPathMatching);
checkAllParametersConsumed("addDOHLocal", vars);
}
+
+ if (useTLS && frontend->d_library == "nghttp2") {
+ if (!frontend->d_tlsContext.d_provider.empty()) {
+ vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext.d_provider);
+ }
+ else {
+#ifdef HAVE_LIBSSL
+ const std::string provider("openssl");
+#else
+ const std::string provider("gnutls");
+#endif
+ vinfolog("Loading default TLS provider '%s'", provider);
+ }
+ }
+
g_dohlocals.push_back(frontend);
auto cs = std::make_unique<ClientState>(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
cs->dohFrontend = frontend;
}
else {
#ifdef HAVE_LIBSSL
- vinfolog("Loading default TLS provider 'openssl'");
+ const std::string provider("openssl");
#else
- vinfolog("Loading default TLS provider 'gnutls'");
+ const std::string provider("gnutls");
#endif
+ vinfolog("Loading default TLS provider '%s'", provider);
}
// only works pre-startup, so no sync necessary
auto cs = std::make_unique<ClientState>(frontend->d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
#include "dnsdist.hh"
#include "dnsdist-concurrent-connections.hh"
#include "dnsdist-ecs.hh"
+#include "dnsdist-nghttp2-in.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-tcp.hh"
d_handler.close();
}
+dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const
+{
+ if (d_ci.cs->dohFrontend) {
+ return dnsdist::Protocol::DoH;
+ }
+ if (d_handler.isTLS()) {
+ return dnsdist::Protocol::DoT;
+ }
+ return dnsdist::Protocol::DoTCP;
+}
+
size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
{
return t_downstreamTCPConnectionsManager.clear();
TCPResponse resp = std::move(state->d_queuedResponses.front());
state->d_queuedResponses.pop_front();
state->d_state = IncomingTCPConnectionState::State::idle;
- result = state->sendResponse(state, now, std::move(resp));
+ result = state->sendResponse(now, std::move(resp));
if (result != IOState::Done) {
return result;
}
return IOState::Done;
}
-static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, TCPResponse& currentResponse)
+void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse)
{
if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
return;
}
- --state->d_currentQueriesCount;
+ --d_currentQueriesCount;
const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
if (currentResponse.d_idstate.selfGenerated == false && ds) {
const auto& ids = currentResponse.d_idstate;
double udiff = ids.queryRealTime.udiff();
- vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff);
+ vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff);
auto backendProtocol = ds->getProtocol();
- if (backendProtocol == dnsdist::Protocol::DoUDP) {
+ if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) {
backendProtocol = dnsdist::Protocol::DoTCP;
}
- ::handleResponseSent(ids, udiff, state->d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true);
+ ::handleResponseSent(ids, udiff, d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true);
} else {
const auto& ids = currentResponse.d_idstate;
- ::handleResponseSent(ids, 0., state->d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
+ ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
}
currentResponse.d_buffer.clear();
return false;
}
- if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
+ // for DoH, this is already handled by the underlying library
+ if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn);
return false;
}
}
/* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
-IOState IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
+IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response)
{
- state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+ d_state = IncomingTCPConnectionState::State::sendingResponse;
uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
that could occur if we had to deal with the size during the processing,
especially alignment issues */
response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2);
- state->d_currentPos = 0;
- state->d_currentResponse = std::move(response);
+ d_currentPos = 0;
+ d_currentResponse = std::move(response);
try {
- auto iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
+ auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
if (iostate == IOState::Done) {
DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
- handleResponseSent(state, state->d_currentResponse);
+ handleResponseSent(d_currentResponse);
return iostate;
} else {
- state->d_lastIOBlocked = true;
+ d_lastIOBlocked = true;
DEBUGLOG("partial write");
return iostate;
}
}
catch (const std::exception& e) {
- vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
+ vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
DEBUGLOG("Closing TCP client connection: "<<e.what());
- ++state->d_ci.cs->tcpDiedSendingResponse;
+ ++d_ci.cs->tcpDiedSendingResponse;
- state->terminateClientConnection();
+ terminateClientConnection();
return IOState::Done;
}
if (state->active()) {
/* and now we restart our own I/O state machine */
- struct timeval now;
- gettimeofday(&now, nullptr);
- handleIO(state, now);
+ state->handleIO();
}
else {
/* we were only waiting for the engine to come back,
try {
auto& ids = response.d_idstate;
unsigned int qnameWireLength;
- if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) {
+ std::shared_ptr<DownstreamState> ds = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
+ if (!ds || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, ds, qnameWireLength)) {
state->terminateClientConnection();
return;
}
- if (response.d_connection->getDS()) {
- ++response.d_connection->getDS()->responses;
+ if (ds) {
+ ++ds->responses;
}
- DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS());
+ DNSResponse dr(ids, response.d_buffer, ds);
dr.d_incomingTCPState = state;
memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
public:
TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
{
- proxyProtocolPayloadSize = 0;
}
~TCPCrossProtocolQuery()
std::shared_ptr<IncomingTCPConnectionState> d_sender;
};
+std::unique_ptr<CrossProtocolQuery> IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds)
+{
+ return std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(state), ds, shared_from_this());
+}
+
std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq)
{
auto state = dq.getIncomingTCPState();
}
}
-static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
+IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional<int32_t> streamID)
{
- if (state->d_querySize < sizeof(dnsheader)) {
+ auto query = std::move(queryIn);
+ if (query.size() < sizeof(dnsheader)) {
++dnsdist::metrics::g_stats.nonCompliantQueries;
- ++state->d_ci.cs->nonCompliantQueries;
- state->terminateClientConnection();
- return;
+ ++d_ci.cs->nonCompliantQueries;
+ return QueryProcessingResult::TooSmall;
}
- ++state->d_queriesCount;
- ++state->d_ci.cs->queries;
+ ++d_queriesCount;
+ ++d_ci.cs->queries;
++dnsdist::metrics::g_stats.queries;
- if (state->d_handler.isTLS()) {
- auto tlsVersion = state->d_handler.getTLSVersion();
+ if (d_handler.isTLS()) {
+ auto tlsVersion = d_handler.getTLSVersion();
switch (tlsVersion) {
case LibsslTLSVersion::TLS10:
- ++state->d_ci.cs->tls10queries;
+ ++d_ci.cs->tls10queries;
break;
case LibsslTLSVersion::TLS11:
- ++state->d_ci.cs->tls11queries;
+ ++d_ci.cs->tls11queries;
break;
case LibsslTLSVersion::TLS12:
- ++state->d_ci.cs->tls12queries;
+ ++d_ci.cs->tls12queries;
break;
case LibsslTLSVersion::TLS13:
- ++state->d_ci.cs->tls13queries;
+ ++d_ci.cs->tls13queries;
break;
default:
- ++state->d_ci.cs->tlsUnknownqueries;
+ ++d_ci.cs->tlsUnknownqueries;
}
}
+ auto state = shared_from_this();
InternalQueryState ids;
- ids.origDest = state->d_proxiedDestination;
- ids.origRemote = state->d_proxiedRemote;
- ids.cs = state->d_ci.cs;
+ ids.origDest = d_proxiedDestination;
+ ids.origRemote = d_proxiedRemote;
+ ids.cs = d_ci.cs;
ids.queryRealTime.start();
+ if (streamID) {
+ ids.d_streamID = *streamID;
+ }
- auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
+ auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
if (dnsCryptResponse) {
TCPResponse response;
- state->d_state = IncomingTCPConnectionState::State::idle;
- ++state->d_currentQueriesCount;
- state->queueResponse(state, now, std::move(response));
- return;
+ d_state = IncomingTCPConnectionState::State::idle;
+ ++d_currentQueriesCount;
+ queueResponse(state, now, std::move(response));
+ return QueryProcessingResult::SelfAnswered;
}
{
/* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
- auto* dh = reinterpret_cast<dnsheader*>(state->d_buffer.data());
- if (!checkQueryHeaders(dh, *state->d_ci.cs)) {
- state->terminateClientConnection();
- return;
+ auto* dh = reinterpret_cast<dnsheader*>(query.data());
+ if (!checkQueryHeaders(dh, *d_ci.cs)) {
+ return QueryProcessingResult::InvalidHeaders;
}
if (dh->qdcount == 0) {
dh->rcode = RCode::NotImp;
dh->qr = true;
response.d_idstate.selfGenerated = true;
- response.d_buffer = std::move(state->d_buffer);
- state->d_state = IncomingTCPConnectionState::State::idle;
- ++state->d_currentQueriesCount;
- state->queueResponse(state, now, std::move(response));
- return;
+ response.d_buffer = std::move(query);
+ d_state = IncomingTCPConnectionState::State::idle;
+ ++d_currentQueriesCount;
+ queueResponse(state, now, std::move(response));
+ return QueryProcessingResult::Empty;
}
}
- ids.qname = DNSName(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
- ids.protocol = dnsdist::Protocol::DoTCP;
+ ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
+ ids.protocol = getProtocol();
if (ids.dnsCryptQuery) {
ids.protocol = dnsdist::Protocol::DNSCryptTCP;
}
- else if (state->d_handler.isTLS()) {
- ids.protocol = dnsdist::Protocol::DoT;
- }
- DNSQuestion dq(ids, state->d_buffer);
+ DNSQuestion dq(ids, query);
const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
ids.origFlags = *flags;
dq.d_incomingTCPState = state;
- dq.sni = state->d_handler.getServerNameIndication();
+ dq.sni = d_handler.getServerNameIndication();
- if (state->d_proxyProtocolValues) {
+ if (d_proxyProtocolValues) {
/* we need to copy them, because the next queries received on that connection will
need to get the _unaltered_ values */
- dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues);
+ dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*d_proxyProtocolValues);
}
if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) {
dq.ids.skipCache = true;
}
- std::shared_ptr<DownstreamState> ds;
- auto result = processQuery(dq, state->d_threadData.holders, ds);
+ if (forwardViaUDPFirst()) {
+ // if there was no EDNS, we add it with a large buffer size
+ // so we can use UDP to talk to the backend.
+ auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(query.data()));
+ if (!dh->arcount) {
+ if (addEDNS(query, 4096, false, 4096, 0)) {
+ dq.ids.ednsAdded = true;
+ }
+ }
+ }
- if (result == ProcessQueryResult::Drop) {
- state->terminateClientConnection();
- return;
+ if (streamID) {
+ auto unit = getDOHUnit(*streamID);
+ dq.ids.du = std::move(unit);
}
- else if (result == ProcessQueryResult::Asynchronous) {
+
+ std::shared_ptr<DownstreamState> ds;
+ auto result = processQuery(dq, d_threadData.holders, ds);
+
+ if (result == ProcessQueryResult::Asynchronous) {
/* we are done for now */
- ++state->d_currentQueriesCount;
- return;
+ ++d_currentQueriesCount;
+ return QueryProcessingResult::Asynchronous;
+ }
+
+ if (streamID) {
+ restoreDOHUnit(std::move(dq.ids.du));
+ }
+
+ if (result == ProcessQueryResult::Drop) {
+ return QueryProcessingResult::Dropped;
}
// the buffer might have been invalidated by now
- const dnsheader* dh = dq.getHeader();
+ uint16_t queryID;
+ {
+ const dnsheader* dh = dq.getHeader();
+ queryID = dh->id;
+ }
+
if (result == ProcessQueryResult::SendAnswer) {
TCPResponse response;
- memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
+ {
+ const dnsheader* dh = dq.getHeader();
+ memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
+ }
response.d_idstate = std::move(ids);
- response.d_idstate.origID = dh->id;
+ response.d_idstate.origID = queryID;
response.d_idstate.selfGenerated = true;
- response.d_idstate.cs = state->d_ci.cs;
- response.d_buffer = std::move(state->d_buffer);
+ response.d_idstate.cs = d_ci.cs;
+ response.d_buffer = std::move(query);
- state->d_state = IncomingTCPConnectionState::State::idle;
- ++state->d_currentQueriesCount;
- state->queueResponse(state, now, std::move(response));
- return;
+ d_state = IncomingTCPConnectionState::State::idle;
+ ++d_currentQueriesCount;
+ queueResponse(state, now, std::move(response));
+ return QueryProcessingResult::SelfAnswered;
}
if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
- state->terminateClientConnection();
- return;
+ return QueryProcessingResult::NoBackend;
}
- dq.ids.origID = dh->id;
+ dq.ids.origID = queryID;
- ++state->d_currentQueriesCount;
+ ++d_currentQueriesCount;
std::string proxyProtocolPayload;
if (ds->isDoH()) {
- vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getNameWithAddr());
+ vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), ds->getNameWithAddr());
/* we need to do this _before_ creating the cross protocol query because
after that the buffer will have been moved */
proxyProtocolPayload = getProxyProtocolPayload(dq);
}
- auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
+ auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(ids), ds, state);
cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
ds->passCrossProtocolQuery(std::move(cpq));
- return;
+ return QueryProcessingResult::Forwarded;
+ }
+ else if (!ds->isTCPOnly() && forwardViaUDPFirst()) {
+ auto unit = getDOHUnit(*streamID);
+ dq.ids.du = std::move(unit);
+ if (assignOutgoingUDPQueryToBackend(ds, queryID, dq, query)) {
+ return QueryProcessingResult::Forwarded;
+ }
+ restoreDOHUnit(std::move(dq.ids.du));
+ // fallback to the normal flow
}
- prependSizeToTCPQuery(state->d_buffer, 0);
+ prependSizeToTCPQuery(query, 0);
- auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
+ auto downstreamConnection = getDownstreamConnection(ds, dq.proxyProtocolValues, now);
if (ds->d_config.useProxyProtocol) {
/* if we ever sent a TLV over a connection, we can never go back */
- if (!state->d_proxyProtocolPayloadHasTLV) {
- state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
+ if (!d_proxyProtocolPayloadHasTLV) {
+ d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
}
proxyProtocolPayload = getProxyProtocolPayload(dq);
downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
}
- TCPQuery query(std::move(state->d_buffer), std::move(ids));
- query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
+ TCPQuery tcpquery(std::move(query), std::move(ids));
+ tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
- vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getNameWithAddr());
+ vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), ds->getNameWithAddr());
std::shared_ptr<TCPQuerySender> incoming = state;
- downstreamConnection->queueQuery(incoming, std::move(query));
+ downstreamConnection->queueQuery(incoming, std::move(tcpquery));
+ return QueryProcessingResult::Forwarded;
}
void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
}
- struct timeval now;
- gettimeofday(&now, nullptr);
- handleIO(conn, now);
+ conn->handleIO();
}
-void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
+void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now)
+{
+ if (d_handler.isTLS()) {
+ if (!d_handler.hasTLSSessionBeenResumed()) {
+ ++d_ci.cs->tlsNewSessions;
+ }
+ else {
+ ++d_ci.cs->tlsResumptions;
+ }
+ if (d_handler.getResumedFromInactiveTicketKey()) {
+ ++d_ci.cs->tlsInactiveTicketKey;
+ }
+ if (d_handler.getUnknownTicketKey()) {
+ ++d_ci.cs->tlsUnknownTicketKey;
+ }
+ }
+
+ d_handshakeDoneTime = now;
+}
+
+IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload()
+{
+ do {
+ DEBUGLOG("reading proxy protocol header");
+ auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed);
+ if (iostate == IOState::Done) {
+ d_buffer.resize(d_currentPos);
+ ssize_t remaining = isProxyHeaderComplete(d_buffer);
+ if (remaining == 0) {
+ vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort());
+ ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
+ return ProxyProtocolResult::Error;
+ }
+ else if (remaining < 0) {
+ d_proxyProtocolNeed += -remaining;
+ d_buffer.resize(d_currentPos + d_proxyProtocolNeed);
+ /* we need to keep reading, since we might have buffered data */
+ }
+ else {
+ /* proxy header received */
+ std::vector<ProxyProtocolValue> proxyProtocolValues;
+ if (!handleProxyProtocol(d_ci.remote, true, *d_threadData.holders.acl, d_buffer, d_proxiedRemote, d_proxiedDestination, proxyProtocolValues)) {
+ vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort());
+ return ProxyProtocolResult::Error;
+ }
+
+ if (!proxyProtocolValues.empty()) {
+ d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
+ }
+
+ return ProxyProtocolResult::Done;
+ }
+ }
+ else {
+ d_lastIOBlocked = true;
+ }
+ }
+ while (active() && !d_lastIOBlocked);
+
+ return ProxyProtocolResult::Reading;
+}
+
+void IncomingTCPConnectionState::handleIO()
{
// why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
// even though the underlying socket is not ready, so we need to actually ask for the data first
IOState iostate = IOState::Done;
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+
do {
iostate = IOState::Done;
- IOStateGuard ioGuard(state->d_ioState);
+ IOStateGuard ioGuard(d_ioState);
- if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
- vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+ if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
// will be handled by the ioGuard
//handleNewIOState(state, IOState::Done, fd, handleIOCallback);
return;
}
- state->d_lastIOBlocked = false;
+ d_lastIOBlocked = false;
try {
- if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+ if (d_state == IncomingTCPConnectionState::State::doingHandshake) {
DEBUGLOG("doing handshake");
- iostate = state->d_handler.tryHandshake();
+ iostate = d_handler.tryHandshake();
if (iostate == IOState::Done) {
DEBUGLOG("handshake done");
- if (state->d_handler.isTLS()) {
- if (!state->d_handler.hasTLSSessionBeenResumed()) {
- ++state->d_ci.cs->tlsNewSessions;
- }
- else {
- ++state->d_ci.cs->tlsResumptions;
- }
- if (state->d_handler.getResumedFromInactiveTicketKey()) {
- ++state->d_ci.cs->tlsInactiveTicketKey;
- }
- if (state->d_handler.getUnknownTicketKey()) {
- ++state->d_ci.cs->tlsUnknownTicketKey;
- }
- }
+ handleHandshakeDone(now);
- state->d_handshakeDoneTime = now;
- if (expectProxyProtocolFrom(state->d_ci.remote)) {
- state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
- state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
- state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+ if (expectProxyProtocolFrom(d_ci.remote)) {
+ d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
+ d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+ d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
}
else {
- state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+ d_state = IncomingTCPConnectionState::State::readingQuerySize;
}
}
else {
- state->d_lastIOBlocked = true;
+ d_lastIOBlocked = true;
}
}
- if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
- do {
- DEBUGLOG("reading proxy protocol header");
- iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed);
- if (iostate == IOState::Done) {
- state->d_buffer.resize(state->d_currentPos);
- ssize_t remaining = isProxyHeaderComplete(state->d_buffer);
- if (remaining == 0) {
- vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort());
- ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
- break;
- }
- else if (remaining < 0) {
- state->d_proxyProtocolNeed += -remaining;
- state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed);
- /* we need to keep reading, since we might have buffered data */
- iostate = IOState::NeedRead;
- }
- else {
- /* proxy header received */
- std::vector<ProxyProtocolValue> proxyProtocolValues;
- if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) {
- vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort());
- break;
- }
-
- if (!proxyProtocolValues.empty()) {
- state->d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
- }
-
- state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
- state->d_buffer.resize(sizeof(uint16_t));
- state->d_currentPos = 0;
- state->d_proxyProtocolNeed = 0;
- break;
- }
- }
- else {
- state->d_lastIOBlocked = true;
- }
+ if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
+ auto status = handleProxyProtocolPayload();
+ if (status == ProxyProtocolResult::Done) {
+ d_state = IncomingTCPConnectionState::State::readingQuerySize;
+ d_buffer.resize(sizeof(uint16_t));
+ d_currentPos = 0;
+ d_proxyProtocolNeed = 0;
+ }
+ else if (status == ProxyProtocolResult::Error) {
+ iostate = IOState::Done;
+ }
+ else {
+ iostate = IOState::NeedRead;
}
- while (state->active() && !state->d_lastIOBlocked);
}
- if (!state->d_lastIOBlocked && (state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
- state->d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
+ if (!d_lastIOBlocked && (d_state == IncomingTCPConnectionState::State::waitingForQuery ||
+ d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
DEBUGLOG("reading query size");
- state->d_buffer.resize(sizeof(uint16_t));
- iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
- if (state->d_currentPos > 0) {
+ d_buffer.resize(sizeof(uint16_t));
+ iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t));
+ if (d_currentPos > 0) {
/* if we got at least one byte, we can't go around sending responses */
- state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+ d_state = IncomingTCPConnectionState::State::readingQuerySize;
}
if (iostate == IOState::Done) {
DEBUGLOG("query size received");
- state->d_state = IncomingTCPConnectionState::State::readingQuery;
- state->d_querySizeReadTime = now;
- if (state->d_queriesCount == 0) {
- state->d_firstQuerySizeReadTime = now;
+ d_state = IncomingTCPConnectionState::State::readingQuery;
+ d_querySizeReadTime = now;
+ if (d_queriesCount == 0) {
+ d_firstQuerySizeReadTime = now;
}
- state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
- if (state->d_querySize < sizeof(dnsheader)) {
+ d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1);
+ if (d_querySize < sizeof(dnsheader)) {
/* go away */
- state->terminateClientConnection();
+ terminateClientConnection();
return;
}
/* allocate a bit more memory to be able to spoof the content, get an answer from the cache
or to add ECS without allocating a new buffer */
- state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
- state->d_currentPos = 0;
+ d_buffer.resize(std::max(d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
+ d_currentPos = 0;
}
else {
- state->d_lastIOBlocked = true;
+ d_lastIOBlocked = true;
}
}
- if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+ if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingQuery) {
DEBUGLOG("reading query");
- iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+ iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize);
if (iostate == IOState::Done) {
DEBUGLOG("query received");
- state->d_buffer.resize(state->d_querySize);
+ d_buffer.resize(d_querySize);
+
+ d_state = IncomingTCPConnectionState::State::idle;
+ auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt);
+ switch (processingResult) {
+ case QueryProcessingResult::TooSmall:
+ /* fall-through */
+ case QueryProcessingResult::InvalidHeaders:
+ /* fall-through */
+ case QueryProcessingResult::Dropped:
+ /* fall-through */
+ case QueryProcessingResult::NoBackend:
+ terminateClientConnection();
+ break;
+ default:
+ break;
+ }
- state->d_state = IncomingTCPConnectionState::State::idle;
- handleQuery(state, now);
/* the state might have been updated in the meantime, we don't want to override it
in that case */
- if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) {
- if (state->d_ioState->isWaitingForRead()) {
+ if (active() && d_state != IncomingTCPConnectionState::State::idle) {
+ if (d_ioState->isWaitingForRead()) {
iostate = IOState::NeedRead;
}
- else if (state->d_ioState->isWaitingForWrite()) {
+ else if (d_ioState->isWaitingForWrite()) {
iostate = IOState::NeedWrite;
}
else {
}
}
else {
- state->d_lastIOBlocked = true;
+ d_lastIOBlocked = true;
}
}
- if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+ if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::sendingResponse) {
DEBUGLOG("sending response");
- iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
+ iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
if (iostate == IOState::Done) {
DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
- handleResponseSent(state, state->d_currentResponse);
- state->d_state = IncomingTCPConnectionState::State::idle;
+ handleResponseSent(d_currentResponse);
+ d_state = IncomingTCPConnectionState::State::idle;
}
else {
- state->d_lastIOBlocked = true;
+ d_lastIOBlocked = true;
}
}
- if (state->active() &&
- !state->d_lastIOBlocked &&
+ if (active() &&
+ !d_lastIOBlocked &&
iostate == IOState::Done &&
- (state->d_state == IncomingTCPConnectionState::State::idle ||
- state->d_state == IncomingTCPConnectionState::State::waitingForQuery))
+ (d_state == IncomingTCPConnectionState::State::idle ||
+ d_state == IncomingTCPConnectionState::State::waitingForQuery))
{
// try sending queued responses
DEBUGLOG("send responses, if any");
+ auto state = shared_from_this();
iostate = sendQueuedResponses(state, now);
- if (!state->d_lastIOBlocked && state->active() && iostate == IOState::Done) {
+ if (!d_lastIOBlocked && active() && iostate == IOState::Done) {
// if the query has been passed to a backend, or dropped, and the responses have been sent,
// we can start reading again
- if (state->canAcceptNewQueries(now)) {
- state->resetForNewQuery();
+ if (canAcceptNewQueries(now)) {
+ resetForNewQuery();
iostate = IOState::NeedRead;
}
else {
- state->d_state = IncomingTCPConnectionState::State::idle;
+ d_state = IncomingTCPConnectionState::State::idle;
iostate = IOState::Done;
}
}
}
- if (state->d_state != IncomingTCPConnectionState::State::idle &&
- state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
- state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
- state->d_state != IncomingTCPConnectionState::State::waitingForQuery &&
- state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
- state->d_state != IncomingTCPConnectionState::State::readingQuery &&
- state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
- vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
+ if (d_state != IncomingTCPConnectionState::State::idle &&
+ d_state != IncomingTCPConnectionState::State::doingHandshake &&
+ d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
+ d_state != IncomingTCPConnectionState::State::waitingForQuery &&
+ d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+ d_state != IncomingTCPConnectionState::State::readingQuery &&
+ d_state != IncomingTCPConnectionState::State::sendingResponse) {
+ vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state));
}
}
catch (const std::exception& e) {
but it might also be a real IO error or something else.
Let's just drop the connection
*/
- if (state->d_state == IncomingTCPConnectionState::State::idle ||
- state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
+ if (d_state == IncomingTCPConnectionState::State::idle ||
+ d_state == IncomingTCPConnectionState::State::waitingForQuery) {
/* no need to increase any counters in that case, the client is simply done with us */
}
- else if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
- state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
- state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
- state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
- state->d_state == IncomingTCPConnectionState::State::readingQuery) {
- ++state->d_ci.cs->tcpDiedReadingQuery;
+ else if (d_state == IncomingTCPConnectionState::State::doingHandshake ||
+ d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
+ d_state == IncomingTCPConnectionState::State::waitingForQuery ||
+ d_state == IncomingTCPConnectionState::State::readingQuerySize ||
+ d_state == IncomingTCPConnectionState::State::readingQuery) {
+ ++d_ci.cs->tcpDiedReadingQuery;
}
- else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+ else if (d_state == IncomingTCPConnectionState::State::sendingResponse) {
/* unlikely to happen here, the exception should be handled in sendResponse() */
- ++state->d_ci.cs->tcpDiedSendingResponse;
+ ++d_ci.cs->tcpDiedSendingResponse;
}
- if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) {
+ if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) {
DEBUGLOG("Got an exception while handling TCP query: "<<e.what());
- vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+ vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), e.what());
}
else {
- vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
+ vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
DEBUGLOG("Closing TCP client connection: "<<e.what());
}
/* remove this FD from the IO multiplexer */
- state->terminateClientConnection();
+ terminateClientConnection();
}
- if (!state->active()) {
+ if (!active()) {
DEBUGLOG("state is no longer active");
return;
}
+ auto state = shared_from_this();
if (iostate == IOState::Done) {
- state->d_ioState->update(iostate, handleIOCallback, state);
+ d_ioState->update(iostate, handleIOCallback, state);
}
else {
updateIO(state, iostate, now);
}
ioGuard.release();
}
- while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked);
+ while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked);
}
-void IncomingTCPConnectionState::notifyIOError(InternalQueryState&& query, const struct timeval& now)
+void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response)
{
if (std::this_thread::get_id() != d_creatorThreadID) {
/* empty buffer will signal an IO error */
- TCPResponse response(PacketBuffer(), std::move(query), nullptr, nullptr);
+ response.d_buffer.clear();
handleCrossProtocolResponse(now, std::move(response));
return;
}
struct timeval now;
gettimeofday(&now, nullptr);
- auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+
+ if (citmp->cs->dohFrontend) {
+#ifdef HAVE_NGHTTP2
+ auto state = std::make_shared<IncomingHTTP2Connection>(std::move(*citmp), *threadData, now);
+ state->handleIO();
+#endif /* HAVE_NGHTTP2 */
+ }
+ else {
+ auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
+ state->handleIO();
+ }
}
static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
auto query = std::move(cpq->query);
auto downstreamServer = std::move(cpq->downstream);
- auto proxyProtocolPayloadSize = cpq->proxyProtocolPayloadSize;
try {
auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
- prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
- query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
+ prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize);
vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
downstream->queueQuery(tqs, std::move(query));
}
catch (...) {
- tqs->notifyIOError(std::move(query.d_idstate), now);
+ tqs->notifyIOError(now, std::move(query));
}
}
try {
if (response.d_response.d_buffer.empty()) {
- response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now);
+ response.d_state->notifyIOError(response.d_now, std::move(response.d_response));
}
else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
{
auto& cs = param.cs;
auto& acl = param.acl;
- int socket = param.socket;
+ const bool checkACL = !cs.dohFrontend || (!cs.dohFrontend->d_trustForwardedForHeader && cs.dohFrontend->d_earlyACLDrop);
+ const int socket = param.socket;
bool tcpClientCountIncremented = false;
ComboAddress remote;
remote.sin4.sin_family = param.local.sin4.sin_family;
throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
}
- if (!acl->match(remote)) {
+ if (checkACL && !acl->match(remote)) {
++dnsdist::metrics::g_stats.aclDrops;
vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
return;
vinfolog("Got TCP connection from %s", remote.toStringWithPort());
ci.remote = remote;
+
if (threadData == nullptr) {
if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(ci)))) {
if (tcpClientCountIncremented) {
else {
struct timeval now;
gettimeofday(&now, nullptr);
- auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+
+ if (ci.cs->dohFrontend) {
+#ifdef HAVE_NGHTTP2
+ auto state = std::make_shared<IncomingHTTP2Connection>(std::move(ci), *threadData, now);
+ state->handleIO();
+#endif /* HAVE_NGHTTP2 */
+ }
+ else {
+ auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now);
+ state->handleIO();
+ }
}
}
catch (const std::exception& e) {
return handleResponse(now, std::move(response));
}
- void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+ void notifyIOError(const struct timeval&, TCPResponse&&) override
{
// nothing to do
}
cout<<"gnutls";
#ifdef HAVE_LIBSSL
cout<<" ";
-#endif /* HAVE_LIBSSL */
+#endif
#endif /* HAVE_GNUTLS */
#ifdef HAVE_LIBSSL
cout<<"openssl";
-#endif /* HAVE_LIBSSL */
+#endif
cout<<") ";
#endif /* HAVE_DNS_OVER_TLS */
#ifdef HAVE_DNS_OVER_HTTPS
cout<<"dns-over-https(";
#ifdef HAVE_LIBH2OEVLOOP
cout<<"h2o";
+#ifdef HAVE_NGHTTP2
+ cout<<" ";
+#endif
#endif /* HAVE_LIBH2OEVLOOP */
+#ifdef HAVE_NGHTTP2
+ cout<<"nghttp2";
+#endif
cout<<") ";
#endif /* HAVE_DNS_OVER_HTTPS */
#ifdef HAVE_DNSCRYPT
#ifdef HAVE_LMDB
cout<<"lmdb ";
#endif
-#ifdef HAVE_NGHTTP2
- cout<<"outgoing-dns-over-https(nghttp2) ";
-#endif
#ifndef DISABLE_PROTOBUF
cout<<"protobuf ";
#endif
std::vector<ClientState*> tcpStates;
std::vector<ClientState*> udpStates;
- for(auto& cs : g_frontends) {
- if (cs->dohFrontend != nullptr) {
+ for (auto& cs : g_frontends) {
+ if (cs->dohFrontend != nullptr && cs->dohFrontend->d_library == "h2o") {
#ifdef HAVE_DNS_OVER_HTTPS
#ifdef HAVE_LIBH2OEVLOOP
std::thread t1(dohThread, cs.get());
AM_CPPFLAGS += $(LIBSSL_CFLAGS)
endif
+if HAVE_GNUTLS
+AM_CPPFLAGS += $(GNUTLS_CFLAGS)
+endif
+
if HAVE_LIBH2OEVLOOP
AM_CPPFLAGS += $(LIBH2OEVLOOP_CFLAGS)
endif
dnsdist-lua.cc dnsdist-lua.hh \
dnsdist-mac-address.cc dnsdist-mac-address.hh \
dnsdist-metrics.cc dnsdist-metrics.hh \
+ dnsdist-nghttp2-in.cc dnsdist-nghttp2-in.hh \
dnsdist-nghttp2.cc dnsdist-nghttp2.hh \
dnsdist-prometheus.hh \
dnsdist-protobuf.cc dnsdist-protobuf.hh \
dnsdist-lua-vars.cc \
dnsdist-mac-address.cc dnsdist-mac-address.hh \
dnsdist-metrics.cc dnsdist-metrics.hh \
+ dnsdist-nghttp2-in.cc dnsdist-nghttp2-in.hh \
dnsdist-nghttp2.cc dnsdist-nghttp2.hh \
dnsdist-protocols.cc dnsdist-protocols.hh \
dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
if HAVE_DNS_OVER_HTTPS
+if HAVE_GNUTLS
+dnsdist_LDADD += -lgnutls
+endif
+
if HAVE_LIBH2OEVLOOP
dnsdist_LDADD += $(LIBH2OEVLOOP_LIBS)
endif
AM_CONDITIONAL([HAVE_LIBH2OEVLOOP], [false])
AM_CONDITIONAL([HAVE_LIBSSL], [false])
AM_CONDITIONAL([HAVE_LMDB], [false])
+AM_CONDITIONAL([HAVE_NGHTTP2], [false])
PDNS_CHECK_LIBCRYPTO
AS_IF([test "x$enable_dns_over_tls" != "xno" -o "x$enable_dns_over_https" != "xno"], [
PDNS_WITH_LIBSSL
+ PDNS_WITH_GNUTLS
])
AS_IF([test "x$enable_dns_over_tls" != "xno"], [
- PDNS_WITH_GNUTLS
-
AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [
AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
])
])
AS_IF([test "x$enable_dns_over_https" != "xno"], [
+ PDNS_WITH_NGHTTP2
PDNS_WITH_LIBH2OEVLOOP
- AS_IF([test "x$HAVE_LIBH2OEVLOOP" != "x1"], [
- AC_MSG_ERROR([DNS over HTTPS support requested but libh2o-evloop was not found])
+ AS_IF([test "x$HAVE_LIBH2OEVLOOP" != "x1" -a "x$HAVE_NGHTTP2" != "x1" ], [
+ AC_MSG_ERROR([DNS over HTTPS support requested but neither libh2o-evloop nor nghttp2 was not found])
])
- AS_IF([test "x$HAVE_LIBSSL" != "x1"], [
- AC_MSG_ERROR([DNS over HTTPS support requested but OpenSSL was not found])
+ AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [
+ AC_MSG_ERROR([DNS over HTTPS support requested but neither GnuTLS nor OpenSSL are available])
])
])
-PDNS_WITH_NGHTTP2
-
DNSDIST_WITH_CDB
PDNS_CHECK_LMDB
PDNS_ENABLE_IPCIPHER
vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec);
auto sender = query->getTCPQuerySender();
if (sender) {
- sender->notifyIOError(std::move(query->query.d_idstate), now);
+ TCPResponse tresponse(std::move(query->query));
+ sender->notifyIOError(now, std::move(tresponse));
}
}
else {
throw std::runtime_error("Unexpected XFR reponse to a health check query");
}
- void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+ void notifyIOError(const struct timeval& now, TCPResponse&&) override
{
++d_data->d_ds->d_healthCheckMetrics.d_networkErrors;
d_data->d_ds->submitHealthCheckResult(d_data->d_initial, false);
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/
#include "dnsdist-internal-queries.hh"
+#include "dnsdist-nghttp2-in.hh"
#include "dnsdist-tcp.hh"
#include "doh.hh"
}
#ifdef HAVE_DNS_OVER_HTTPS
else if (protocol == dnsdist::Protocol::DoH) {
- return getDoHCrossProtocolQueryFromDQ(dq, isResponse);
+#ifdef HAVE_LIBH2OEVLOOP
+ if (dq.ids.cs->dohFrontend->d_library == "h2o") {
+ return getDoHCrossProtocolQueryFromDQ(dq, isResponse);
+ }
+#endif /* HAVE_LIBH2OEVLOOP */
+ return getTCPCrossProtocolQueryFromDQ(dq);
}
#endif
else {
struct timeval now;
gettimeofday(&now, nullptr);
- sender->notifyIOError(std::move(query->query.d_idstate), now);
+ TCPResponse tresponse(std::move(query->query));
+ sender->notifyIOError(now, std::move(tresponse));
return true;
}
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "base64.hh"
+#include "dnsdist-nghttp2-in.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsparser.hh"
+
+#ifdef HAVE_NGHTTP2
+
+#if 0
+class IncomingDoHCrossProtocolContext : public CrossProtocolContext
+{
+public:
+ IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr<IncomingHTTP2Connection> connection, IncomingHTTP2Connection::StreamID streamID): CrossProtocolContext(std::move(query.d_buffer)), d_connection(connection), d_query(std::move(query))
+ {
+ }
+
+ std::optional<std::string> getHTTPPath() const override
+ {
+ return d_query.d_path;
+ }
+
+ std::optional<std::string> getHTTPScheme() const override
+ {
+ return d_query.d_scheme;
+ }
+
+ std::optional<std::string> getHTTPHost() const override
+ {
+ return d_query.d_host;
+ }
+
+ std::optional<std::string> getHTTPQueryString() const override
+ {
+ return d_query.d_queryString;
+ }
+
+ std::optional<HeadersMap> getHTTPHeaders() const override
+ {
+ if (!d_query.d_headers) {
+ return std::nullopt;
+ }
+ return *d_query.d_headers;
+ }
+
+ void handleResponse(PacketBuffer&& response, InternalQueryState&& state) override
+ {
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+ }
+
+ void handleTimeout() override
+ {
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+ }
+
+ ~IncomingDoHCrossProtocolContext() override
+ {
+ }
+
+private:
+ std::weak_ptr<IncomingHTTP2Connection> d_connection;
+ IncomingHTTP2Connection::PendingQuery d_query;
+ IncomingHTTP2Connection::StreamID d_streamID{-1};
+};
+#endif
+
+class IncomingDoHCrossProtocolContext : public DOHUnitInterface
+{
+public:
+ IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr<IncomingHTTP2Connection> connection, IncomingHTTP2Connection::StreamID streamID) :
+ d_connection(connection), d_query(std::move(query)), d_streamID(streamID)
+ {
+ }
+
+ std::string getHTTPPath() const override
+ {
+ return d_query.d_path;
+ }
+
+ const std::string& getHTTPScheme() const override
+ {
+ return d_query.d_scheme;
+ }
+
+ const std::string& getHTTPHost() const override
+ {
+ return d_query.d_host;
+ }
+
+ std::string getHTTPQueryString() const override
+ {
+ return d_query.d_queryString;
+ }
+
+ const HeadersMap& getHTTPHeaders() const override
+ {
+ if (!d_query.d_headers) {
+ static const HeadersMap empty{};
+ return empty;
+ }
+ return *d_query.d_headers;
+ }
+
+ void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "") override
+ {
+ d_query.d_statusCode = statusCode;
+ d_query.d_response = std::move(body);
+ d_query.d_contentTypeOut = contentType;
+ }
+
+ void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds) override
+ {
+ std::unique_ptr<DOHUnitInterface> unit(this);
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+
+ state.du = std::move(unit);
+ TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr);
+ resp.d_ds = ds;
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+ conn->handleResponse(now, std::move(resp));
+ }
+
+ void handleTimeout() override
+ {
+ std::unique_ptr<DOHUnitInterface> unit(this);
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+ TCPResponse resp;
+ resp.d_idstate.d_streamID = d_streamID;
+ conn->notifyIOError(now, std::move(resp));
+ }
+
+ ~IncomingDoHCrossProtocolContext() override
+ {
+ }
+
+ std::weak_ptr<IncomingHTTP2Connection> d_connection;
+ IncomingHTTP2Connection::PendingQuery d_query;
+ IncomingHTTP2Connection::StreamID d_streamID{-1};
+};
+
+void IncomingHTTP2Connection::handleResponse(const struct timeval& now, TCPResponse&& response)
+{
+ if (std::this_thread::get_id() != d_creatorThreadID) {
+ handleCrossProtocolResponse(now, std::move(response));
+ return;
+ }
+
+ auto& state = response.d_idstate;
+ if (state.forwardedOverUDP) {
+ dnsheader* responseDH = reinterpret_cast<struct dnsheader*>(response.d_buffer.data());
+
+ if (responseDH->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) {
+ auto& query = *state.d_packet;
+ dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(query.data() + state.d_proxyProtocolPayloadSize);
+ /* restoring the original ID */
+ queryDH->id = state.origID;
+
+ state.forwardedOverUDP = false;
+ auto cpq = getCrossProtocolQuery(std::move(query), std::move(state), response.d_ds);
+ cpq->query.d_proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0;
+ if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
+ return;
+ }
+ else {
+ vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP");
+ notifyIOError(now, std::move(response));
+ return;
+ }
+ }
+ }
+
+ IncomingTCPConnectionState::handleResponse(now, std::move(response));
+}
+
+std::unique_ptr<DOHUnitInterface> IncomingHTTP2Connection::getDOHUnit(uint32_t streamID)
+{
+ auto query = std::move(d_currentStreams.at(streamID));
+ return std::make_unique<IncomingDoHCrossProtocolContext>(std::move(query), std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this()), streamID);
+}
+
+void IncomingHTTP2Connection::restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&& unit)
+{
+ auto context = std::unique_ptr<IncomingDoHCrossProtocolContext>(dynamic_cast<IncomingDoHCrossProtocolContext*>(unit.release()));
+ d_currentStreams.at(context->d_streamID) = std::move(context->d_query);
+}
+
+void IncomingHTTP2Connection::restoreContext(uint32_t streamID, IncomingHTTP2Connection::PendingQuery&& context)
+{
+ d_currentStreams.at(streamID) = std::move(context);
+}
+
+IncomingHTTP2Connection::IncomingHTTP2Connection(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now) :
+ IncomingTCPConnectionState(std::move(ci), threadData, now)
+{
+ nghttp2_session_callbacks* cbs = nullptr;
+ if (nghttp2_session_callbacks_new(&cbs) != 0) {
+ throw std::runtime_error("Unable to create a callback object for a new incoming HTTP/2 session");
+ }
+ std::unique_ptr<nghttp2_session_callbacks, void (*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
+ cbs = nullptr;
+
+ nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback);
+ nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback);
+ nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
+ nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback);
+ nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
+ nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
+ nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback);
+
+ nghttp2_session* sess = nullptr;
+ if (nghttp2_session_server_new(&sess, callbacks.get(), this) != 0) {
+ throw std::runtime_error("Coult not allocate a new incoming HTTP/2 session");
+ }
+
+ d_session = std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)>(sess, nghttp2_session_del);
+ sess = nullptr;
+}
+
+bool IncomingHTTP2Connection::checkALPN()
+{
+ constexpr std::array<uint8_t, 2> h2{'h', '2'};
+ auto protocols = d_handler.getNextProtocol();
+ if (protocols.size() == h2.size() && memcmp(protocols.data(), h2.data(), h2.size()) == 0) {
+ return true;
+ }
+ vinfolog("DoH connection from %s expected ALPN value 'h2', got '%s'", d_ci.remote.toStringWithPort(), std::string(protocols.begin(), protocols.end()));
+ return false;
+}
+
+void IncomingHTTP2Connection::handleConnectionReady()
+{
+ constexpr std::array<nghttp2_settings_entry, 1> iv{{{NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100U}}};
+ auto ret = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv.data(), iv.size());
+ if (ret != 0) {
+ throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret)));
+ }
+ ret = nghttp2_session_send(d_session.get());
+ if (ret != 0) {
+ throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret)));
+ }
+}
+
+void IncomingHTTP2Connection::handleIO()
+{
+ IOState iostate = IOState::Done;
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+
+ try {
+ if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
+ stopIO();
+ d_connectionDied = true;
+ return;
+ }
+
+ if (d_state == State::doingHandshake) {
+ iostate = d_handler.tryHandshake();
+ if (iostate == IOState::Done) {
+ handleHandshakeDone(now);
+ if (d_handler.isTLS()) {
+ if (!checkALPN()) {
+ d_connectionDied = true;
+ stopIO();
+ return;
+ }
+ }
+
+ if (expectProxyProtocolFrom(d_ci.remote)) {
+ d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
+ d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+ d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+ }
+ else {
+ d_state = State::waitingForQuery;
+ handleConnectionReady();
+ }
+ }
+ }
+
+ if (d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
+ auto status = handleProxyProtocolPayload();
+ if (status == ProxyProtocolResult::Done) {
+ d_currentPos = 0;
+ d_proxyProtocolNeed = 0;
+ d_buffer.clear();
+ d_state = State::waitingForQuery;
+ handleConnectionReady();
+ }
+ else if (status == ProxyProtocolResult::Error) {
+ d_connectionDied = true;
+ stopIO();
+ return;
+ }
+ }
+
+ if (d_state == State::waitingForQuery) {
+ readHTTPData();
+ }
+
+ if (!d_connectionDied) {
+ auto shared = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+ if (nghttp2_session_want_read(d_session.get())) {
+ d_ioState->add(IOState::NeedRead, &handleReadableIOCallback, shared, boost::none);
+ }
+ if (nghttp2_session_want_write(d_session.get())) {
+ d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, shared, boost::none);
+ }
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+ d_connectionDied = true;
+ stopIO();
+ }
+}
+
+ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+{
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+ bool bufferWasEmpty = conn->d_out.empty();
+ conn->d_out.insert(conn->d_out.end(), data, data + length);
+
+ if (bufferWasEmpty) {
+ try {
+ auto state = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
+ if (state == IOState::Done) {
+ conn->d_out.clear();
+ conn->d_outPos = 0;
+ if (!conn->isIdle()) {
+ conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
+ }
+ else {
+ conn->watchForRemoteHostClosingConnection();
+ }
+ }
+ else {
+ conn->updateIO(state, handleWritableIOCallback);
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to write (send) to incoming HTTP connection: %s", e.what());
+ conn->handleIOError();
+ }
+ }
+
+ return length;
+}
+
+static const std::unordered_map<std::string, std::string> s_constants{
+ {"200-value", "200"},
+ {"method-name", ":method"},
+ {"method-value", "POST"},
+ {"scheme-name", ":scheme"},
+ {"scheme-value", "https"},
+ {"authority-name", ":authority"},
+ {"x-forwarded-for-name", "x-forwarded-for"},
+ {"path-name", ":path"},
+ {"content-length-name", "content-length"},
+ {"status-name", ":status"},
+ {"location-name", "location"},
+ {"accept-name", "accept"},
+ {"accept-value", "application/dns-message"},
+ {"cache-control-name", "cache-control"},
+ {"content-type-name", "content-type"},
+ {"content-type-value", "application/dns-message"},
+ {"user-agent-name", "user-agent"},
+ {"user-agent-value", "nghttp2-" NGHTTP2_VERSION "/dnsdist"},
+ {"x-forwarded-port-name", "x-forwarded-port"},
+ {"x-forwarded-proto-name", "x-forwarded-proto"},
+ {"x-forwarded-proto-value-dns-over-udp", "dns-over-udp"},
+ {"x-forwarded-proto-value-dns-over-tcp", "dns-over-tcp"},
+ {"x-forwarded-proto-value-dns-over-tls", "dns-over-tls"},
+ {"x-forwarded-proto-value-dns-over-http", "dns-over-http"},
+ {"x-forwarded-proto-value-dns-over-https", "dns-over-https"},
+};
+
+static const std::string s_authorityHeaderName(":authority");
+static const std::string s_pathHeaderName(":path");
+static const std::string s_methodHeaderName(":method");
+static const std::string s_schemeHeaderName(":scheme");
+static const std::string s_xForwardedForHeaderName("x-forwarded-for");
+
+void NGHTTP2Headers::addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey)
+{
+ const auto& name = s_constants.at(nameKey);
+ const auto& value = s_constants.at(valueKey);
+
+ headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
+}
+
+void NGHTTP2Headers::addCustomDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& name, const std::string_view& value)
+{
+ headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.data())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.data())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
+}
+
+void NGHTTP2Headers::addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string_view& value)
+{
+ const auto& name = s_constants.at(nameKey);
+ NGHTTP2Headers::addCustomDynamicHeader(headers, name, value);
+}
+
+IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResponse&& response)
+{
+ assert(response.d_idstate.d_streamID != -1);
+ auto& context = d_currentStreams.at(response.d_idstate.d_streamID);
+
+ uint32_t statusCode = 200U;
+ std::string contentType;
+ bool sendContentType = true;
+ auto& responseBuffer = context.d_buffer;
+ if (context.d_statusCode != 0) {
+ responseBuffer = std::move(context.d_response);
+ statusCode = context.d_statusCode;
+ contentType = std::move(context.d_contentTypeOut);
+ }
+ else {
+ responseBuffer = std::move(response.d_buffer);
+ }
+
+ sendResponse(response.d_idstate.d_streamID, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
+ handleResponseSent(response);
+
+ return IOState::Done;
+}
+
+void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response)
+{
+ if (std::this_thread::get_id() != d_creatorThreadID) {
+ /* empty buffer will signal an IO error */
+ response.d_buffer.clear();
+ handleCrossProtocolResponse(now, std::move(response));
+ return;
+ }
+
+ assert(response.d_idstate.d_streamID != -1);
+ d_currentStreams.at(response.d_idstate.d_streamID).d_buffer = std::move(response.d_buffer);
+ sendResponse(response.d_idstate.d_streamID, 502, d_ci.cs->dohFrontend->d_customResponseHeaders);
+}
+
+bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType, bool addContentType)
+{
+ /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set.
+ */
+ nghttp2_data_provider data_provider;
+
+ data_provider.source.ptr = this;
+ data_provider.read_callback = [](nghttp2_session*, IncomingHTTP2Connection::StreamID stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* cb_data) -> ssize_t {
+ auto connection = reinterpret_cast<IncomingHTTP2Connection*>(cb_data);
+ auto& obj = connection->d_currentStreams.at(stream_id);
+ size_t toCopy = 0;
+ if (obj.d_queryPos < obj.d_buffer.size()) {
+ size_t remaining = obj.d_buffer.size() - obj.d_queryPos;
+ toCopy = length > remaining ? remaining : length;
+ memcpy(buf, &obj.d_buffer.at(obj.d_queryPos), toCopy);
+ obj.d_queryPos += toCopy;
+ }
+
+ if (obj.d_queryPos >= obj.d_buffer.size()) {
+ *data_flags |= NGHTTP2_DATA_FLAG_EOF;
+ }
+ return toCopy;
+ };
+
+ const auto& df = d_ci.cs->dohFrontend;
+ auto& responseBody = d_currentStreams.at(streamID).d_buffer;
+
+ std::vector<nghttp2_nv> headers;
+ std::string responseCodeStr;
+ std::string cacheControlValue;
+ std::string location;
+ /* remember that dynamic header values should be kept alive
+ until we have called nghttp2_submit_response(), at least */
+
+ if (responseCode == 200) {
+ NGHTTP2Headers::addStaticHeader(headers, "status-name", "200-value");
+ ++df->d_validresponses;
+ ++df->d_http2Stats.d_nb200Responses;
+
+ if (addContentType) {
+ if (contentType.empty()) {
+ NGHTTP2Headers::addStaticHeader(headers, "content-type-name", "content-type-value");
+ }
+ else {
+ NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", contentType);
+ }
+ }
+
+ if (df->d_sendCacheControlHeaders && responseBody.size() > sizeof(dnsheader)) {
+ uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast<const char*>(responseBody.data()), responseBody.size());
+ if (minTTL != std::numeric_limits<uint32_t>::max()) {
+ cacheControlValue = "max-age=" + std::to_string(minTTL);
+ NGHTTP2Headers::addDynamicHeader(headers, "cache-control-name", cacheControlValue);
+ }
+ }
+ }
+ else {
+ responseCodeStr = std::to_string(responseCode);
+ NGHTTP2Headers::addDynamicHeader(headers, "status-name", responseCodeStr);
+
+ if (responseCode >= 300 && responseCode < 400) {
+ location = std::string(reinterpret_cast<const char*>(responseBody.data()), responseBody.size());
+ NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", "text/html; charset=utf-8");
+ NGHTTP2Headers::addDynamicHeader(headers, "location-name", location);
+ static const std::string s_redirectStart{"<!DOCTYPE html><TITLE>Moved</TITLE><P>The document has moved <A HREF=\""};
+ static const std::string s_redirectEnd{"\">here</A>"};
+ responseBody.reserve(s_redirectStart.size() + responseBody.size() + s_redirectEnd.size());
+ responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end());
+ responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end());
+ ++df->d_redirectresponses;
+ }
+ else {
+ ++df->d_errorresponses;
+ switch (responseCode) {
+ case 400:
+ ++df->d_http2Stats.d_nb400Responses;
+ break;
+ case 403:
+ ++df->d_http2Stats.d_nb403Responses;
+ break;
+ case 500:
+ ++df->d_http2Stats.d_nb500Responses;
+ break;
+ case 502:
+ ++df->d_http2Stats.d_nb502Responses;
+ break;
+ default:
+ ++df->d_http2Stats.d_nbOtherResponses;
+ break;
+ }
+
+ if (!responseBody.empty()) {
+ NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", "text/plain; charset=utf-8");
+ }
+ else {
+ static const std::string invalid{"invalid DNS query"};
+ static const std::string notAllowed{"dns query not allowed"};
+ static const std::string noDownstream{"no downstream server available"};
+ static const std::string internalServerError{"Internal Server Error"};
+
+ switch (responseCode) {
+ case 400:
+ responseBody.insert(responseBody.begin(), invalid.begin(), invalid.end());
+ break;
+ case 403:
+ responseBody.insert(responseBody.begin(), notAllowed.begin(), notAllowed.end());
+ break;
+ case 502:
+ responseBody.insert(responseBody.begin(), noDownstream.begin(), noDownstream.end());
+ break;
+ case 500:
+ /* fall-through */
+ default:
+ responseBody.insert(responseBody.begin(), internalServerError.begin(), internalServerError.end());
+ break;
+ }
+ }
+ }
+ }
+
+ const std::string contentLength = std::to_string(responseBody.size());
+ NGHTTP2Headers::addDynamicHeader(headers, "content-length-name", contentLength);
+
+ for (const auto& [key, value] : customResponseHeaders) {
+ NGHTTP2Headers::addCustomDynamicHeader(headers, key, value);
+ }
+
+ auto ret = nghttp2_submit_response(d_session.get(), streamID, headers.data(), headers.size(), &data_provider);
+ if (ret != 0) {
+ d_currentStreams.erase(streamID);
+ vinfolog("Error submitting HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret));
+ return false;
+ }
+
+ ret = nghttp2_session_send(d_session.get());
+ if (ret != 0) {
+ d_currentStreams.erase(streamID);
+ vinfolog("Error flushing HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret));
+ return false;
+ }
+
+ return true;
+}
+
+static void processForwardedForHeader(const std::unique_ptr<HeadersMap>& headers, ComboAddress& remote)
+{
+ if (!headers) {
+ return;
+ }
+
+ auto it = headers->find(s_xForwardedForHeaderName);
+ if (it == headers->end()) {
+ return;
+ }
+
+ std::string_view value = it->second;
+ try {
+ auto pos = value.rfind(',');
+ if (pos != std::string_view::npos) {
+ ++pos;
+ for (; pos < value.size() && value[pos] == ' '; ++pos) {
+ }
+
+ if (pos < value.size()) {
+ value = value.substr(pos);
+ }
+ }
+ auto newRemote = ComboAddress(std::string(value));
+ remote = newRemote;
+ }
+ catch (const std::exception& e) {
+ vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.what());
+ }
+ catch (const PDNSException& e) {
+ vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.reason);
+ }
+}
+
+static std::optional<PacketBuffer> getPayloadFromPath(const std::string_view& path)
+{
+ std::optional<PacketBuffer> result{std::nullopt};
+
+ if (path.size() <= 5) {
+ return result;
+ }
+
+ auto pos = path.find("?dns=");
+ if (pos == string::npos) {
+ pos = path.find("&dns=");
+ }
+
+ if (pos == string::npos) {
+ return result;
+ }
+
+ // need to base64url decode this
+ string sdns(path.substr(pos + 5));
+ boost::replace_all(sdns, "-", "+");
+ boost::replace_all(sdns, "_", "/");
+
+ // re-add padding that may have been missing
+ switch (sdns.size() % 4) {
+ case 2:
+ sdns.append(2, '=');
+ break;
+ case 3:
+ sdns.append(1, '=');
+ break;
+ }
+
+ PacketBuffer decoded;
+ /* rough estimate so we hopefully don't need a new allocation later */
+ /* We reserve at few additional bytes to be able to add EDNS later */
+ const size_t estimate = ((sdns.size() * 3) / 4);
+ decoded.reserve(estimate);
+ if (B64Decode(sdns, decoded) < 0) {
+ return result;
+ }
+
+ result = std::move(decoded);
+ return result;
+}
+
+void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::PendingQuery&& query, IncomingHTTP2Connection::StreamID streamID)
+{
+ const auto handleImmediateResponse = [this, &query, streamID](uint16_t code, const std::string& reason, PacketBuffer&& response = PacketBuffer()) {
+ if (response.empty()) {
+ query.d_buffer.clear();
+ query.d_buffer.insert(query.d_buffer.begin(), reason.begin(), reason.end());
+ }
+ else {
+ query.d_buffer = std::move(response);
+ }
+ vinfolog("Sending an immediate %d response to incoming DoH query: %s", code, reason);
+ sendResponse(streamID, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
+ };
+
+ ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries;
+
+ if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) {
+ processForwardedForHeader(query.d_headers, d_proxiedRemote);
+
+ /* second ACL lookup based on the updated address */
+ auto& holders = d_threadData.holders;
+ if (!holders.acl->match(d_proxiedRemote)) {
+ ++dnsdist::metrics::g_stats.aclDrops;
+ vinfolog("Query from %s (%s) (DoH) dropped because of ACL", d_ci.remote.toStringWithPort(), d_proxiedRemote.toStringWithPort());
+ handleImmediateResponse(403, "DoH query not allowed because of ACL");
+ return;
+ }
+
+ if (!d_ci.cs->dohFrontend->d_keepIncomingHeaders) {
+ query.d_headers.reset();
+ }
+ }
+
+ if (d_ci.cs->dohFrontend->d_exactPathMatching) {
+ if (d_ci.cs->dohFrontend->d_urls.count(query.d_path) == 0) {
+ handleImmediateResponse(404, "there is no endpoint configured for this path");
+ return;
+ }
+ }
+ else {
+ bool found = false;
+ for (const auto& path : d_ci.cs->dohFrontend->d_urls) {
+ if (boost::starts_with(query.d_path, path)) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ handleImmediateResponse(404, "there is no endpoint configured for this path");
+ return;
+ }
+ }
+
+ /* the responses map can be updated at runtime, so we need to take a copy of
+ the shared pointer, increasing the reference counter */
+ auto responsesMap = d_ci.cs->dohFrontend->d_responsesMap;
+ if (responsesMap) {
+ for (const auto& entry : *responsesMap) {
+ if (entry->matches(query.d_path)) {
+ const auto& customHeaders = entry->getHeaders();
+ query.d_buffer = entry->getContent();
+ if (entry->getStatusCode() >= 400 && query.d_buffer.size() >= 1) {
+ // legacy trailing 0 from the h2o era
+ query.d_buffer.pop_back();
+ }
+
+ sendResponse(streamID, entry->getStatusCode(), customHeaders ? *customHeaders : d_ci.cs->dohFrontend->d_customResponseHeaders, std::string(), false);
+ return;
+ }
+ }
+ }
+
+ if (query.d_buffer.empty() && query.d_method == PendingQuery::Method::Get && !query.d_queryString.empty()) {
+ auto payload = getPayloadFromPath(query.d_queryString);
+ if (payload) {
+ query.d_buffer = std::move(*payload);
+ }
+ else {
+ ++d_ci.cs->dohFrontend->d_badrequests;
+ handleImmediateResponse(400, "DoH unable to decode BASE64-URL");
+ return;
+ }
+ }
+
+ if (query.d_method == PendingQuery::Method::Get) {
+ ++d_ci.cs->dohFrontend->d_getqueries;
+ }
+ else if (query.d_method == PendingQuery::Method::Post) {
+ ++d_ci.cs->dohFrontend->d_postqueries;
+ }
+
+ try {
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+ auto processingResult = handleQuery(std::move(query.d_buffer), now, streamID);
+
+ switch (processingResult) {
+ case QueryProcessingResult::TooSmall:
+ handleImmediateResponse(400, "DoH non-compliant query");
+ break;
+ case QueryProcessingResult::InvalidHeaders:
+ handleImmediateResponse(400, "DoH invalid headers");
+ break;
+ case QueryProcessingResult::Empty:
+ handleImmediateResponse(200, "DoH empty query", std::move(query.d_buffer));
+ break;
+ case QueryProcessingResult::Dropped:
+ handleImmediateResponse(403, "DoH dropped query");
+ break;
+ case QueryProcessingResult::NoBackend:
+ handleImmediateResponse(502, "DoH no backend available");
+ return;
+ case QueryProcessingResult::Forwarded:
+ case QueryProcessingResult::Asynchronous:
+ case QueryProcessingResult::SelfAnswered:
+ break;
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while processing DoH query: %s", e.what());
+ handleImmediateResponse(400, "DoH non-compliant query");
+ return;
+ }
+}
+
+int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+{
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+#if 0
+ switch (frame->hd.type) {
+ case NGHTTP2_HEADERS:
+ cerr<<"got headers"<<endl;
+ if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
+ cerr<<"All headers received"<<endl;
+ }
+ if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
+ cerr<<"All headers received - query"<<endl;
+ }
+ break;
+ case NGHTTP2_WINDOW_UPDATE:
+ cerr<<"got window update"<<endl;
+ break;
+ case NGHTTP2_SETTINGS:
+ cerr<<"got settings"<<endl;
+ cerr<<frame->settings.niv<<endl;
+ for (size_t idx = 0; idx < frame->settings.niv; idx++) {
+ cerr<<"- "<<frame->settings.iv[idx].settings_id<<" "<<frame->settings.iv[idx].value<<endl;
+ }
+ break;
+ case NGHTTP2_DATA:
+ cerr<<"got data"<<endl;
+ break;
+ }
+#endif
+
+ if (frame->hd.type == NGHTTP2_GOAWAY) {
+ conn->stopIO();
+ if (conn->isIdle()) {
+ if (nghttp2_session_want_write(conn->d_session.get())) {
+ conn->d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, conn, boost::none);
+ }
+ }
+ }
+
+ /* is this the last frame for this stream? */
+ else if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) {
+ auto streamID = frame->hd.stream_id;
+ auto stream = conn->d_currentStreams.find(streamID);
+ if (stream != conn->d_currentStreams.end()) {
+ conn->handleIncomingQuery(std::move(stream->second), streamID);
+
+ if (conn->isIdle()) {
+ conn->watchForRemoteHostClosingConnection();
+ }
+ }
+ else {
+ vinfolog("Stream %d NOT FOUND", streamID);
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ }
+
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_stream_close_callback(nghttp2_session* session, IncomingHTTP2Connection::StreamID stream_id, uint32_t error_code, void* user_data)
+{
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+
+ if (error_code == 0) {
+ return 0;
+ }
+
+ auto stream = conn->d_currentStreams.find(stream_id);
+ if (stream == conn->d_currentStreams.end()) {
+ /* we don't care, then */
+ return 0;
+ }
+
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+ auto request = std::move(stream->second);
+ conn->d_currentStreams.erase(stream->first);
+
+ if (conn->isIdle()) {
+ conn->watchForRemoteHostClosingConnection();
+ }
+
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+{
+ if (frame->hd.type != NGHTTP2_HEADERS || frame->headers.cat != NGHTTP2_HCAT_REQUEST) {
+ return 0;
+ }
+
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+ auto insertPair = conn->d_currentStreams.insert({frame->hd.stream_id, PendingQuery()});
+ if (!insertPair.second) {
+ /* there is a stream ID collision, something is very wrong! */
+ vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
+ conn->d_connectionDied = true;
+ nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
+ auto ret = nghttp2_session_send(conn->d_session.get());
+ if (ret != 0) {
+ vinfolog("Error flushing HTTP response for stream %d from %s: %s", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+ return 0;
+ }
+
+ return 0;
+}
+
+static std::string::size_type getLengthOfPathWithoutParameters(const std::string_view& path)
+{
+ auto pos = path.find("?");
+ if (pos == string::npos) {
+ return path.size();
+ }
+
+ return pos;
+}
+
+int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t nameLen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data)
+{
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+
+ if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
+ if (nghttp2_check_header_name(name, nameLen) == 0) {
+ vinfolog("Invalid header name");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+#if HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113
+ if (nghttp2_check_header_value_rfc9113(value, valuelen) == 0) {
+ vinfolog("Invalid header value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+#endif /* HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113 */
+
+ auto headerMatches = [name, nameLen](const std::string& expected) -> bool {
+ return nameLen == expected.size() && memcmp(name, expected.data(), expected.size()) == 0;
+ };
+
+ auto stream = conn->d_currentStreams.find(frame->hd.stream_id);
+ if (stream == conn->d_currentStreams.end()) {
+ vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id);
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ auto& query = stream->second;
+ auto valueView = std::string_view(reinterpret_cast<const char*>(value), valuelen);
+ if (headerMatches(s_pathHeaderName)) {
+#if HAVE_NGHTTP2_CHECK_PATH
+ if (nghttp2_check_path(value, valuelen) == 0) {
+ vinfolog("Invalid path value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+#endif /* HAVE_NGHTTP2_CHECK_PATH */
+
+ auto pathLen = getLengthOfPathWithoutParameters(valueView);
+ query.d_path = valueView.substr(0, pathLen);
+ if (pathLen < valueView.size()) {
+ query.d_queryString = valueView.substr(pathLen);
+ }
+ }
+ else if (headerMatches(s_authorityHeaderName)) {
+ query.d_host = valueView;
+ }
+ else if (headerMatches(s_schemeHeaderName)) {
+ query.d_scheme = valueView;
+ }
+ else if (headerMatches(s_methodHeaderName)) {
+#if HAVE_NGHTTP2_CHECK_METHOD
+ if (nghttp2_check_method(value, valuelen) == 0) {
+ vinfolog("Invalid method value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+#endif /* HAVE_NGHTTP2_CHECK_METHOD */
+
+ if (valueView == "GET") {
+ query.d_method = PendingQuery::Method::Get;
+ }
+ else if (valueView == "POST") {
+ query.d_method = PendingQuery::Method::Post;
+ }
+ else {
+ vinfolog("Unsupported method value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ }
+
+ if (conn->d_ci.cs->dohFrontend->d_keepIncomingHeaders || (conn->d_ci.cs->dohFrontend->d_trustForwardedForHeader && headerMatches(s_xForwardedForHeaderName))) {
+ if (!query.d_headers) {
+ query.d_headers = std::make_unique<HeadersMap>();
+ }
+ query.d_headers->insert({std::string(reinterpret_cast<const char*>(name), nameLen), std::string(valueView)});
+ }
+ }
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, IncomingHTTP2Connection::StreamID stream_id, const uint8_t* data, size_t len, void* user_data)
+{
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+ auto stream = conn->d_currentStreams.find(stream_id);
+ if (stream == conn->d_currentStreams.end()) {
+ vinfolog("Unable to match the stream ID %d to a known one!", stream_id);
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ if (len > std::numeric_limits<uint16_t>::max() || (std::numeric_limits<uint16_t>::max() - stream->second.d_buffer.size()) < len) {
+ vinfolog("Data frame of size %d is too large for a DNS query (we already have %d)", len, stream->second.d_buffer.size());
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+ stream->second.d_buffer.insert(stream->second.d_buffer.end(), data, data + len);
+
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data)
+{
+ IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+
+ vinfolog("Error in HTTP/2 connection from %d: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len));
+ conn->d_connectionDied = true;
+ nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
+ auto ret = nghttp2_session_send(conn->d_session.get());
+ if (ret != 0) {
+ vinfolog("Error flushing HTTP response on connection from %s: %s", conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+ return 0;
+}
+
+void IncomingHTTP2Connection::readHTTPData()
+{
+ IOStateGuard ioGuard(d_ioState);
+ do {
+ size_t got = 0;
+ d_in.resize(d_in.size() + 512);
+ try {
+ IOState newState = d_handler.tryRead(d_in, got, d_in.size(), true);
+ d_in.resize(got);
+
+ if (got > 0) {
+ /* we got something */
+ auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
+ /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
+ all data should be consumed before returning */
+ if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
+ throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
+ }
+
+ nghttp2_session_send(d_session.get());
+ }
+
+ if (newState == IOState::Done) {
+ if (isIdle()) {
+ watchForRemoteHostClosingConnection();
+ ioGuard.release();
+ break;
+ }
+ }
+ else {
+ if (newState == IOState::NeedWrite) {
+ updateIO(IOState::NeedWrite, handleReadableIOCallback);
+ }
+ ioGuard.release();
+ break;
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to read from HTTP backend connection: %s", e.what());
+ handleIOError();
+ break;
+ }
+ } while (getConcurrentStreamsCount() > 0);
+}
+
+void IncomingHTTP2Connection::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+ auto conn = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
+ conn->handleIO();
+}
+
+void IncomingHTTP2Connection::handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+ auto conn = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
+ IOStateGuard ioGuard(conn->d_ioState);
+
+ try {
+ IOState newState = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
+ if (newState == IOState::NeedRead) {
+ conn->updateIO(IOState::NeedRead, handleWritableIOCallback);
+ }
+ else if (newState == IOState::Done) {
+ conn->d_out.clear();
+ conn->d_outPos = 0;
+ if (!conn->isIdle()) {
+ conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
+ }
+ else {
+ conn->watchForRemoteHostClosingConnection();
+ }
+ }
+ ioGuard.release();
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to write (ready) to HTTP backend connection: %s", e.what());
+ conn->handleIOError();
+ }
+}
+
+bool IncomingHTTP2Connection::isIdle() const
+{
+ return getConcurrentStreamsCount() == 0;
+}
+
+void IncomingHTTP2Connection::stopIO()
+{
+ d_ioState->reset();
+}
+
+uint32_t IncomingHTTP2Connection::getConcurrentStreamsCount() const
+{
+ return d_currentStreams.size();
+}
+
+boost::optional<struct timeval> IncomingHTTP2Connection::getIdleClientReadTTD(struct timeval now) const
+{
+ auto idleTimeout = d_ci.cs->dohFrontend->d_idleTimeout;
+ if (g_maxTCPConnectionDuration == 0 && idleTimeout == 0) {
+ return boost::none;
+ }
+
+ if (g_maxTCPConnectionDuration > 0) {
+ auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+ if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+ return now;
+ }
+ auto remaining = g_maxTCPConnectionDuration - elapsed;
+ if (idleTimeout == 0 || remaining <= static_cast<size_t>(idleTimeout)) {
+ now.tv_sec += remaining;
+ return now;
+ }
+ }
+
+ now.tv_sec += idleTimeout;
+ return now;
+}
+
+void IncomingHTTP2Connection::updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback)
+{
+ boost::optional<struct timeval> ttd{boost::none};
+
+ auto shared = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+ if (shared) {
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+
+ if (newState == IOState::NeedRead) {
+ if (isIdle()) {
+ ttd = getIdleClientReadTTD(now);
+ }
+ else {
+ ttd = getClientReadTTD(now);
+ }
+ d_ioState->update(newState, callback, shared, ttd);
+ }
+ else if (newState == IOState::NeedWrite) {
+ ttd = getClientWriteTTD(now);
+ d_ioState->update(newState, callback, shared, ttd);
+ }
+ }
+}
+
+void IncomingHTTP2Connection::watchForRemoteHostClosingConnection()
+{
+ updateIO(IOState::NeedRead, handleReadableIOCallback);
+}
+
+void IncomingHTTP2Connection::handleIOError()
+{
+ d_connectionDied = true;
+ nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR);
+ d_currentStreams.clear();
+ stopIO();
+}
+#endif /* HAVE_NGHTTP2 */
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "config.h"
+#ifdef HAVE_NGHTTP2
+#include <nghttp2/nghttp2.h>
+
+#include "dnsdist-tcp-upstream.hh"
+
+class IncomingHTTP2Connection : public IncomingTCPConnectionState
+{
+public:
+ using StreamID = int32_t;
+
+ class PendingQuery
+ {
+ public:
+ enum class Method : uint8_t
+ {
+ Unknown,
+ Get,
+ Post
+ };
+
+ PacketBuffer d_buffer;
+ PacketBuffer d_response;
+ std::string d_path;
+ std::string d_scheme;
+ std::string d_host;
+ std::string d_queryString;
+ std::string d_sni;
+ std::string d_contentTypeOut;
+ std::unique_ptr<HeadersMap> d_headers;
+ size_t d_queryPos{0};
+ uint32_t d_statusCode{0};
+ Method d_method{Method::Unknown};
+ };
+
+ IncomingHTTP2Connection(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now);
+ ~IncomingHTTP2Connection() = default;
+ void handleIO() override;
+ void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+ void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
+ void restoreContext(uint32_t streamID, PendingQuery&& context);
+
+private:
+ static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data);
+ static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data);
+ static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, StreamID stream_id, const uint8_t* data, size_t len, void* user_data);
+ static int on_stream_close_callback(nghttp2_session* session, StreamID stream_id, uint32_t error_code, void* user_data);
+ static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data);
+ static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data);
+ static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data);
+ static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+ static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+ IOState sendResponse(const struct timeval& now, TCPResponse&& response) override;
+ bool forwardViaUDPFirst() const override
+ {
+ return true;
+ }
+ void restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&&) override;
+ std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID) override;
+
+ void stopIO();
+ bool isIdle() const;
+ uint32_t getConcurrentStreamsCount() const;
+ void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback);
+ void watchForRemoteHostClosingConnection();
+ void handleIOError();
+ bool sendResponse(StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true);
+ void handleIncomingQuery(PendingQuery&& query, StreamID streamID);
+ bool checkALPN();
+ void readHTTPData();
+ void handleConnectionReady();
+ boost::optional<struct timeval> getIdleClientReadTTD(struct timeval now) const;
+
+ std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)> d_session{nullptr, nghttp2_session_del};
+ std::unordered_map<StreamID, PendingQuery> d_currentStreams;
+ PacketBuffer d_out;
+ PacketBuffer d_in;
+ size_t d_outPos{0};
+ bool d_connectionDied{false};
+};
+
+class NGHTTP2Headers
+{
+public:
+ static void addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey);
+ static void addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string_view& value);
+ static void addCustomDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& name, const std::string_view& value);
+};
+
+#endif /* HAVE_NGHTTP2 */
#endif /* HAVE_NGHTTP2 */
#include "dnsdist-nghttp2.hh"
+#include "dnsdist-nghttp2-in.hh"
#include "dnsdist-tcp.hh"
#include "dnsdist-tcp-downstream.hh"
#include "dnsdist-downstream-connection.hh"
}
}
- request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this(), d_ds));
+ TCPResponse response(std::move(request.d_query));
+ response.d_buffer = std::move(request.d_buffer);
+ response.d_connection = shared_from_this();
+ response.d_ds = d_ds;
+ request.d_sender->handleResponse(now, std::move(response));
}
catch (const std::exception& e) {
vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what());
d_ds->reportTimeoutOrError();
}
- request.d_sender->notifyIOError(std::move(request.d_query.d_idstate), now);
+ TCPResponse response(PacketBuffer(), std::move(request.d_query.d_idstate), nullptr, nullptr);
+ request.d_sender->notifyIOError(now, std::move(response));
}
catch (const std::exception& e) {
vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what());
return getConcurrentStreamsCount() == 0;
}
-const std::unordered_map<std::string, std::string> DoHConnectionToBackend::s_constants = {
- {"method-name", ":method"},
- {"method-value", "POST"},
- {"scheme-name", ":scheme"},
- {"scheme-value", "https"},
- {"accept-name", "accept"},
- {"accept-value", "application/dns-message"},
- {"content-type-name", "content-type"},
- {"content-type-value", "application/dns-message"},
- {"user-agent-name", "user-agent"},
- {"user-agent-value", "nghttp2-" NGHTTP2_VERSION "/dnsdist"},
- {"authority-name", ":authority"},
- {"path-name", ":path"},
- {"content-length-name", "content-length"},
- {"x-forwarded-for-name", "x-forwarded-for"},
- {"x-forwarded-port-name", "x-forwarded-port"},
- {"x-forwarded-proto-name", "x-forwarded-proto"},
- {"x-forwarded-proto-value-dns-over-udp", "dns-over-udp"},
- {"x-forwarded-proto-value-dns-over-tcp", "dns-over-tcp"},
- {"x-forwarded-proto-value-dns-over-tls", "dns-over-tls"},
- {"x-forwarded-proto-value-dns-over-http", "dns-over-http"},
- {"x-forwarded-proto-value-dns-over-https", "dns-over-https"},
-};
-
-void DoHConnectionToBackend::addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey)
-{
- const auto& name = s_constants.at(nameKey);
- const auto& value = s_constants.at(valueKey);
-
- headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
-}
-
-void DoHConnectionToBackend::addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& value)
-{
- const auto& name = s_constants.at(nameKey);
-
- headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
-}
-
void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
{
auto payloadSize = std::to_string(query.d_buffer.size());
headers.reserve(8 + (addXForwarded ? 3 : 0));
/* Pseudo-headers need to come first (rfc7540 8.1.2.1) */
- addStaticHeader(headers, "method-name", "method-value");
- addStaticHeader(headers, "scheme-name", "scheme-value");
- addDynamicHeader(headers, "authority-name", d_ds->d_config.d_tlsSubjectName);
- addDynamicHeader(headers, "path-name", d_ds->d_config.d_dohPath);
- addStaticHeader(headers, "accept-name", "accept-value");
- addStaticHeader(headers, "content-type-name", "content-type-value");
- addStaticHeader(headers, "user-agent-name", "user-agent-value");
- addDynamicHeader(headers, "content-length-name", payloadSize);
+ NGHTTP2Headers::addStaticHeader(headers, "method-name", "method-value");
+ NGHTTP2Headers::addStaticHeader(headers, "scheme-name", "scheme-value");
+ NGHTTP2Headers::addDynamicHeader(headers, "authority-name", d_ds->d_config.d_tlsSubjectName);
+ NGHTTP2Headers::addDynamicHeader(headers, "path-name", d_ds->d_config.d_dohPath);
+ NGHTTP2Headers::addStaticHeader(headers, "accept-name", "accept-value");
+ NGHTTP2Headers::addStaticHeader(headers, "content-type-name", "content-type-value");
+ NGHTTP2Headers::addStaticHeader(headers, "user-agent-name", "user-agent-value");
+ NGHTTP2Headers::addDynamicHeader(headers, "content-length-name", payloadSize);
/* no need to add these headers for health-check queries */
if (addXForwarded && query.d_idstate.origRemote.getPort() != 0) {
remote = query.d_idstate.origRemote.toString();
remotePort = std::to_string(query.d_idstate.origRemote.getPort());
- addDynamicHeader(headers, "x-forwarded-for-name", remote);
- addDynamicHeader(headers, "x-forwarded-port-name", remotePort);
+ NGHTTP2Headers::addDynamicHeader(headers, "x-forwarded-for-name", remote);
+ NGHTTP2Headers::addDynamicHeader(headers, "x-forwarded-port-name", remotePort);
if (query.d_idstate.cs != nullptr) {
if (query.d_idstate.cs->isUDP()) {
- addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-udp");
+ NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-udp");
}
else if (query.d_idstate.cs->isDoH()) {
if (query.d_idstate.cs->hasTLS()) {
- addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-https");
+ NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-https");
}
else {
- addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-http");
+ NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-http");
}
}
else if (query.d_idstate.cs->hasTLS()) {
- addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tls");
+ NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tls");
}
else {
- addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tcp");
+ NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tcp");
}
}
}
downstream->queueQuery(tqs, std::move(query));
}
catch (...) {
- tqs->notifyIOError(std::move(query.d_idstate), now);
+ TCPResponse response(std::move(query));
+ tqs->notifyIOError(now, std::move(response));
}
}
static bool getSerialFromIXFRQuery(TCPQuery& query)
{
try {
- size_t proxyPayloadSize = query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0;
+ size_t proxyPayloadSize = query.d_proxyProtocolPayloadAdded ? query.d_idstate.d_proxyProtocolPayloadSize : 0;
if (query.d_buffer.size() <= (proxyPayloadSize + sizeof(uint16_t))) {
return false;
}
if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
query.d_proxyProtocolPayloadAdded = true;
- query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size();
+ query.d_idstate.d_proxyProtocolPayloadSize = query.d_proxyProtocolPayload.size();
}
}
else if (connectionState == ConnectionState::proxySent) {
if (query.d_proxyProtocolPayloadAdded) {
- if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) {
+ if (query.d_buffer.size() < query.d_idstate.d_proxyProtocolPayloadSize) {
throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
}
- query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize);
+ query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_idstate.d_proxyProtocolPayloadSize);
query.d_proxyProtocolPayloadAdded = false;
- query.d_proxyProtocolPayloadAddedSize = 0;
+ query.d_idstate.d_proxyProtocolPayloadSize = 0;
}
}
if (query.d_idstate.qclass == QClass::IN && query.d_idstate.qtype == QType::IXFR) {
getSerialFromIXFRQuery(query);
}
- editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true);
+ editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_idstate.d_proxyProtocolPayloadSize : 0, true);
}
IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
/* this one can't be restarted, sorry */
DEBUGLOG("A XFR for which a response has already been sent cannot be restarted");
try {
- pending.second.d_sender->notifyIOError(std::move(pending.second.d_query.d_idstate), now);
+ TCPResponse response(std::move(pending.second.d_query));
+ pending.second.d_sender->notifyIOError(now, std::move(response));
}
catch (const std::exception& e) {
vinfolog("Got an exception while notifying: %s", e.what());
increaseCounters(d_currentQuery.d_query.d_idstate.cs);
auto sender = d_currentQuery.d_sender;
if (sender->active()) {
- sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now);
+ TCPResponse response(std::move(d_currentQuery.d_query));
+ sender->notifyIOError(now, std::move(response));
}
}
increaseCounters(query.d_query.d_idstate.cs);
auto sender = query.d_sender;
if (sender->active()) {
- sender->notifyIOError(std::move(query.d_query.d_idstate), now);
+ TCPResponse response(std::move(query.d_query));
+ sender->notifyIOError(now, std::move(response));
}
}
increaseCounters(response.second.d_query.d_idstate.cs);
auto sender = response.second.d_sender;
if (sender->active()) {
- sender->notifyIOError(std::move(response.second.d_query.d_idstate), now);
+ TCPResponse tresp(std::move(response.second.d_query));
+ sender->notifyIOError(now, std::move(tresp));
}
}
}
if (sender->active()) {
DEBUGLOG("passing response to client connection for "<<ids.qname);
// make sure that we still exist after calling handleResponse()
- sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds));
+ TCPResponse response(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds);
+ sender->handleResponse(now, std::move(response));
}
if (!d_pendingQueries.empty()) {
#include "dolog.hh"
#include "dnsdist-tcp.hh"
+#include "dnsdist-tcp-downstream.hh"
struct TCPCrossProtocolResponse;
class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
{
public:
- IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
+ enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Empty, Dropped, SelfAnswered, NoBackend, Asynchronous };
+ enum class ProxyProtocolResult : uint8_t { Reading, Done, Error };
+
+ IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
{
d_origDest.reset();
d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
- ~IncomingTCPConnectionState();
+ virtual ~IncomingTCPConnectionState();
void resetForNewQuery();
static size_t clearAllDownstreamConnections();
- static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param);
static void updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now);
- static IOState sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
-static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+
+ virtual void handleIO();
- /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */
- void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+ QueryProcessingResult handleQuery(PacketBuffer&& query, const struct timeval& now, std::optional<int32_t> streamID);
+ virtual void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+ virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
- void notifyIOError(InternalQueryState&& query, const struct timeval& now) override;
+ virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response);
+ void handleResponseSent(TCPResponse& currentResponse);
+ void handleHandshakeDone(const struct timeval& now);
+ ProxyProtocolResult handleProxyProtocolPayload();
void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
void terminateClientConnection();
- void queueQuery(TCPQuery&& query);
bool canAcceptNewQueries(const struct timeval& now);
{
return d_ioState != nullptr;
}
+ virtual bool forwardViaUDPFirst() const
+ {
+ return false;
+ }
+ virtual std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID)
+ {
+ throw std::runtime_error("Getting a DOHUnit state from a generic TCP/DoT connection is not supported");
+ }
+ virtual void restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&&)
+ {
+ throw std::runtime_error("Restoring a DOHUnit state to a generic TCP/DoT connection is not supported");
+ }
+
+ std::unique_ptr<CrossProtocolQuery> getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds);
std::string toString() const
{
return o.str();
}
+ dnsdist::Protocol getProtocol() const;
+
enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
TCPResponse d_currentResponse;
*/
#pragma once
+#include <optional>
#include <unistd.h>
#include "channel.hh"
#include "iputils.hh"
InternalQueryState d_idstate;
std::string d_proxyProtocolPayload;
PacketBuffer d_buffer;
- uint32_t d_proxyProtocolPayloadAddedSize{0};
uint32_t d_ixfrQuerySerial{0};
uint32_t d_xfrMasterSerial{0};
uint32_t d_xfrSerialCount{0};
}
}
+ TCPResponse(TCPQuery&& query) :
+ TCPQuery(std::move(query))
+ {
+ if (d_buffer.size() >= sizeof(dnsheader)) {
+ memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+ }
+ else {
+ memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
+ }
+ }
+
bool isAsync() const
{
return d_async;
virtual bool active() const = 0;
virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0;
virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0;
- virtual void notifyIOError(InternalQueryState&& query, const struct timeval& now) = 0;
+ virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) = 0;
/* whether the connection should be automatically released to the pool after handleResponse()
has been called */
InternalQuery query;
std::shared_ptr<DownstreamState> downstream{nullptr};
- size_t proxyProtocolPayloadSize{0};
bool d_isResponse{false};
};
return handleResponse(now, std::move(response));
}
- void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+ void notifyIOError(const struct timeval& now, TCPResponse&& response) override
{
+ auto& query = response.d_idstate;
if (!query.du) {
return;
}
if (!holders.acl->match(remote)) {
++dnsdist::metrics::g_stats.aclDrops;
vinfolog("Query from %s (DoH) dropped because of ACL", remote.toStringWithPort());
- h2o_send_error_403(req, "Forbidden", "dns query not allowed because of ACL", 0);
+ h2o_send_error_403(req, "Forbidden", "DoH query not allowed because of ACL", 0);
return 0;
}
return;
}
+ if (dsc->df->d_earlyACLDrop && !dsc->df->d_trustForwardedForHeader && !dsc->holders.acl->match(remote)) {
+ ++dnsdist::metrics::g_stats.aclDrops;
+ vinfolog("Dropping DoH connection from %s because of ACL", remote.toStringWithPort());
+ h2o_socket_close(sock);
+ return;
+ }
+
if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
vinfolog("Dropping DoH connection from %s because we have too many from this client already", remote.toStringWithPort());
h2o_socket_close(sock);
AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_HTTPS], [
AC_MSG_CHECKING([whether to enable incoming DNS over HTTPS (DoH) support])
AC_ARG_ENABLE([dns-over-https],
- AS_HELP_STRING([--enable-dns-over-https], [enable incoming DNS over HTTPS (DoH) support (requires libh2o) @<:@default=no@:>@]),
+ AS_HELP_STRING([--enable-dns-over-https], [enable incoming DNS over HTTPS (DoH) support (requires libh2o or nghttp2) @<:@default=no@:>@]),
[enable_dns_over_https=$enableval],
[enable_dns_over_https=no]
)
PKG_CHECK_MODULES([NGHTTP2], [libnghttp2], [
[HAVE_NGHTTP2=1]
AC_DEFINE([HAVE_NGHTTP2], [1], [Define to 1 if you have nghttp2])
+ save_CFLAGS=$CFLAGS
+ save_LIBS=$LIBS
+ CFLAGS="$NGHTTP2_CFLAGS $CFLAGS"
+ LIBS="$NGHTTP2_LIBS $LIBS"
+ AC_CHECK_FUNCS([nghttp2_check_header_value_rfc9113 nghttp2_check_method nghttp2_check_path])
+ CFLAGS=$save_CFLAGS
+ LIBS=$save_LIBS
], [ : ])
])
])
{
}
- void notifyIOError(InternalQueryState&&, const struct timeval&) override
+ void notifyIOError(const struct timeval&, TCPResponse&&) override
{
errorRaised = true;
}
/* add stub implementations, we don't want to include the corresponding object files
and their dependencies */
-// NOLINTNEXTLINE(readability-convert-member-functions-to-static): this is a stub, the real one is not that simple..
bool TLSFrontend::setupTLS()
{
return true;
d_valid = true;
}
- void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
+ void handleXFRResponse(const struct timeval&, TCPResponse&&) override
{
}
- void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+ void notifyIOError(const struct timeval&, TCPResponse&&) override
{
d_error = true;
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
BOOST_CHECK(s_writeBuffer == query);
}
dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count);
#endif
}
dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
struct timeval later = now;
later.tv_sec += g_tcpRecvTimeout + 1;
dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
struct timeval later = now;
later.tv_sec += g_tcpRecvTimeout + 1;
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
}
}
dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * 2U);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
}
dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
struct timeval later = now;
later.tv_sec += g_tcpRecvTimeout + 1;
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
BOOST_CHECK(s_writeBuffer == query);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
BOOST_CHECK(s_backendWriteBuffer == query);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
BOOST_CHECK(s_backendWriteBuffer == query);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
BOOST_CHECK(s_backendWriteBuffer == query);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
/* set the incoming descriptor as ready! */
dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
struct timeval later = now;
later.tv_sec += backend->d_config.tcpSendTimeout + 1;
auto expiredWriteConns = threadData.mplexer->getTimeouts(later, true);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
struct timeval later = now;
later.tv_sec += backend->d_config.tcpRecvTimeout + 1;
auto expiredConns = threadData.mplexer->getTimeouts(later, false);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
BOOST_CHECK(s_writeBuffer == query);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size() * backend->d_config.d_retries);
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
BOOST_CHECK(s_writeBuffer == query);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size() * backend->d_config.d_retries);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
BOOST_CHECK(s_backendWriteBuffer == query);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count);
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
/* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while ((threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
threadData.mplexer->run(&now);
}
};
auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
- IncomingTCPConnectionState::handleIO(state, now);
+ state->handleIO();
while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
threadData.mplexer->run(&now);
}
size_t getTicketsKeysCount() override;
};
-void dohThread(ClientState* clientState);
+void dohThread(ClientState* cs);
#endif /* HAVE_LIBH2OEVLOOP */
#endif /* HAVE_DNS_OVER_HTTPS */
newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
}
#endif /* HAVE_LIBSSL */
+
if (!newCtx) {
#ifdef HAVE_LIBSSL
newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
{
-#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
+#ifdef HAVE_DNS_OVER_TLS
/* get the "best" available provider */
if (!params.d_provider.empty()) {
#ifdef HAVE_GNUTLS
#endif /* HAVE_GNUTLS */
#endif /* HAVE_LIBSSL */
-#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
+#endif /* HAVE_DNS_OVER_TLS */
return nullptr;
}
public:
enum class ALPN : uint8_t { Unset, DoT, DoH };
- TLSFrontend(ALPN alpn) : d_alpn(alpn)
+ TLSFrontend(ALPN alpn): d_alpn(alpn)
{
}
class TCPIOHandler
{
public:
- enum class Type : uint8_t { Client, Server };
TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx): d_socket(socket)
{
bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
{
- return true;
+ return false;
}
namespace dnsdist {
return sock
@classmethod
- def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
+ def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if timeout:
# 2.7.9+
if hasattr(ssl, 'create_default_context'):
sslctx = ssl.create_default_context(cafile=caCert)
+ if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
+ sslctx.set_alpn_protocols(alpn)
sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
else:
sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+ # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+ conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+ # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+ conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
# this path is not in the URLs map and should lead to a 404
(_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL + "NotPowerDNS", query, caFile=self._caCert, useQueue=False, rawResponse=True)
self.assertTrue(receivedResponse)
- self.assertEqual(receivedResponse, b'not found')
+ self.assertIn(receivedResponse, [b'there is no endpoint configured for this path', b'not found'])
self.assertEqual(self._rcode, 404)
# this path is below one in the URLs map and exactPathMatching is false, so we should be good
(receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 127.0.0.1:42, 127.0.0.1'])
self.assertEqual(self._rcode, 403)
- self.assertEqual(receivedResponse, b'dns query not allowed because of ACL')
+ self.assertEqual(receivedResponse, b'DoH query not allowed because of ACL')
class TestDOHForwardedForNoTrusted(DNSDistDOHTest):
newServer{address="127.0.0.1:%s"}
setACL('192.0.2.1/32')
- addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" })
+ addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }, {earlyACLDrop=true})
"""
_config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey']
'127.0.0.1')
response.answer.append(rrset)
- (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 192.0.2.1:4200'])
+ dropped = False
+ try:
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 192.0.2.1:4200'])
+ self.assertEqual(self._rcode, 403)
+ self.assertEqual(receivedResponse, b'DoH query not allowed because of ACL')
+ except pycurl.error as e:
+ dropped = True
- self.assertEqual(self._rcode, 403)
- self.assertEqual(receivedResponse, b'dns query not allowed because of ACL')
+ self.assertTrue(dropped)
class TestDOHFrontendLimits(DNSDistDOHTest):
for idx in range(self._maxTCPConnsPerDOHFrontend + 1):
try:
- conns.append(self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert))
+ conns.append(self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=['h2']))
except:
conns.append(None)
elif method == "sendDOHQueryWrapper":
pbMessageType = dnsmessage_pb2.PBDNSMessage.DOH
- print(method)
self.checkProtobufQuery(msg, pbMessageType, query, dns.rdataclass.IN, dns.rdatatype.A, name)
self.assertEqual(len(msg.meta), 5)
tags = {}