From: Remi Gacogne Date: Fri, 6 Aug 2021 15:01:03 +0000 (+0200) Subject: dnsdist: Better downstream DoH support, better DoT/DoH ALPN handling X-Git-Tag: dnsdist-1.7.0-alpha1~23^2~33 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e82bf80f0294c8c9702b29c10d359d4dc3de1bbc;p=thirdparty%2Fpdns.git dnsdist: Better downstream DoH support, better DoT/DoH ALPN handling --- diff --git a/m4/pdns_with_gnutls.m4 b/m4/pdns_with_gnutls.m4 index c693dff81f..b6ad100bbb 100644 --- a/m4/pdns_with_gnutls.m4 +++ b/m4/pdns_with_gnutls.m4 @@ -18,7 +18,7 @@ AC_DEFUN([PDNS_WITH_GNUTLS], [ save_LIBS=$LIBS CFLAGS="$GNUTLS_CFLAGS $CFLAGS" LIBS="$GNUTLS_LIBS $LIBS" - AC_CHECK_FUNCS([gnutls_memset gnutls_session_set_verify_cert gnutls_session_get_verify_cert_status]) + AC_CHECK_FUNCS([gnutls_memset gnutls_session_set_verify_cert gnutls_session_get_verify_cert_status gnutls_alpn_set_protocols]) CFLAGS=$save_CFLAGS LIBS=$save_LIBS diff --git a/m4/pdns_with_libssl.m4 b/m4/pdns_with_libssl.m4 index c42905fd1d..3e32bc4086 100644 --- a/m4/pdns_with_libssl.m4 +++ b/m4/pdns_with_libssl.m4 @@ -17,7 +17,7 @@ AC_DEFUN([PDNS_WITH_LIBSSL], [ save_LIBS=$LIBS CFLAGS="$LIBSSL_CFLAGS $CFLAGS" LIBS="$LIBSSL_LIBS -lcrypto $LIBS" - AC_CHECK_FUNCS([SSL_CTX_set_ciphersuites OCSP_basic_sign SSL_CTX_set_num_tickets SSL_CTX_set_keylog_callback SSL_CTX_get0_privatekey SSL_CTX_set_min_proto_version SSL_set_hostflags]) + AC_CHECK_FUNCS([SSL_CTX_set_ciphersuites OCSP_basic_sign SSL_CTX_set_num_tickets SSL_CTX_set_keylog_callback SSL_CTX_get0_privatekey SSL_CTX_set_min_proto_version SSL_set_hostflags SSL_CTX_set_alpn_protos SSL_CTX_set_next_proto_select_cb SSL_get0_alpn_selected SSL_get0_next_proto_negotiated SSL_CTX_set_alpn_select_cb]) CFLAGS=$save_CFLAGS LIBS=$save_LIBS diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index b1f8465b1b..c63486621d 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -40,6 +40,7 @@ #ifdef LUAJIT_VERSION #include "dnsdist-lua-ffi.hh" #endif /* LUAJIT_VERSION */ +#include "dnsdist-nghttp2.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" #include "dnsdist-secpoll.hh" @@ -528,10 +529,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } ret->d_tlsCtx = getTLSContext(tlsParams); - } - if (vars.count("dohPath")) { - ret->d_dohPath = boost::get(vars.at("dohPath")); + if (vars.count("dohPath")) { + ret->d_dohPath = boost::get(vars.at("dohPath")); + if (ret->d_tlsCtx) { + setupDoHClientProtocolNegotiation(ret->d_tlsCtx); + } + } + else { + setupDoTProtocolNegotiation(ret->d_tlsCtx); + } } /* this needs to be done _AFTER_ the order has been set, diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 4e0fdec163..2ca6320c7f 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -592,6 +592,7 @@ static void handleQuery(std::shared_ptr& state, cons prependSizeToTCPQuery(state->d_buffer, 0); +#warning FIXME: handle DoH backends here auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); bool proxyProtocolPayloadAdded = false; @@ -784,7 +785,15 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptractive() && state->d_state != IncomingTCPConnectionState::State::idle) { - iostate = state->d_ioState->getState(); + if (state->d_ioState->isWaitingForRead()) { + iostate = IOState::NeedRead; + } + else if (state->d_ioState->isWaitingForWrite()) { + iostate = IOState::NeedWrite; + } + else { + iostate = IOState::Done; + } } } else { @@ -860,9 +869,9 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptrd_ci.cs->tcpDiedSendingResponse; } - if (state->d_ioState->getState() == IOState::NeedWrite || state->d_queriesCount == 0) { + if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) { DEBUGLOG("Got an exception while handling TCP query: "<d_ioState->getState() == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); + vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); } else { vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); @@ -1018,15 +1027,19 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par delete tmp; tmp = nullptr; - auto downstream = DownstreamConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); + try { + auto downstream = DownstreamConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); - prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); - downstream->queueQuery(tqs, std::move(query)); + prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); + downstream->queueQuery(tqs, std::move(query)); + } + catch (...) { + tqs->notifyIOError(std::move(query.d_idstate), now); + } } catch (...) { delete tmp; tmp = nullptr; - throw; } } diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 592a21f6d4..b248a98a1e 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -34,9 +34,6 @@ #include "threadname.hh" #include "sstuff.hh" -#warning remove me -#include "dnswriter.hh" - std::atomic g_dohStatesDumpRequested{0}; std::unique_ptr g_dohClientThreads{nullptr}; @@ -45,28 +42,24 @@ class DoHConnectionToBackend: public TCPConnectionToBackend public: DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now); - void handleTimeout(const struct timeval& now, bool write) override - { -#warning FIXME: we should notify the owners of pending queries / responses - } - + void handleTimeout(const struct timeval& now, bool write) override; void queueQuery(std::shared_ptr& sender, TCPQuery&& query) override; std::string toString() const override { ostringstream o; - //o << "DoH connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket"); + o << "DoH connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", "< d_sender{nullptr}; TCPQuery d_query; PacketBuffer d_buffer; + uint16_t d_responseCode{0}; bool d_finished{false}; }; + void addToIOState(IOState state, FDMultiplexer::callbackfunc_t callback); void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback); void stopIO(); void handleResponse(PendingRequest&& request); + void handleResponseError(PendingRequest&& request, const struct timeval& now); + uint32_t getConcurrentStreamsCount() const; - //std::deque d_pendingQueries; + size_t getUsageCount() const + { + auto ref = shared_from_this(); + return ref.use_count(); + } + + static const std::unordered_map s_constants; std::unique_ptr d_session{nullptr, nghttp2_session_del}; std::unordered_map d_currentStreams; PacketBuffer d_out; PacketBuffer d_in; + size_t d_queryPos{0}; size_t d_outPos{0}; size_t d_inPos{0}; + uint32_t d_highestStreamID{0}; }; +class DownstreamDoHConnectionsManager +{ +public: + static std::shared_ptr getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now); + static void releaseDownstreamConnection(std::shared_ptr&& conn); + static void cleanupClosedConnections(struct timeval now); + static size_t clear(); + + static void setMaxCachedConnectionsPerDownstream(size_t max) + { + s_maxCachedConnectionsPerDownstream = max; + } + + static void setCleanupInterval(uint16_t interval) + { + s_cleanupInterval = interval; + } + +private: + static thread_local map>> t_downstreamConnections; + static size_t s_maxCachedConnectionsPerDownstream; + static time_t s_nextCleanup; + static uint16_t s_cleanupInterval; +}; + +uint32_t DoHConnectionToBackend::getConcurrentStreamsCount() const +{ + return d_currentStreams.size(); +} void DoHConnectionToBackend::handleResponse(PendingRequest&& request) { - cerr<<"handle response!"<handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this())); } +void DoHConnectionToBackend::handleResponseError(PendingRequest&& request, const struct timeval& now) +{ + request.d_sender->notifyIOError(std::move(request.d_query.d_idstate), now); +} + +void DoHConnectionToBackend::handleTimeout(const struct timeval& now, bool write) +{ + d_connectionDied = true; + for (auto& request : d_currentStreams) { + handleResponseError(std::move(request.second), now); + } + d_currentStreams.clear(); +} + +bool DoHConnectionToBackend::canBeReused() const +{ + if (d_connectionDied) { + return false; + } + const uint32_t maximumStreamID = (static_cast(1) << 31) - 1; + if (d_highestStreamID == maximumStreamID) { + return false; + } + + //cerr<<"Got "< DoHConnectionToBackend::s_constants = { + { "method-name", ":method" }, + { "method-value", "POST" }, +}; + void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, TCPQuery&& query) { /* we could use nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_NAME and nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_VALUE @@ -122,8 +192,9 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, and that it is already lowercased. */ auto payloadSize = std::to_string(query.d_buffer.size()); d_currentQuery = std::move(query); + d_queryPos = 0; const nghttp2_nv hdrs[] = { - MAKE_NV2(":method", "POST"), + { const_cast(reinterpret_cast(s_constants.at("method-name").c_str())), const_cast(reinterpret_cast(s_constants.at("method-value").c_str())), s_constants.at("method-name").size(), s_constants.at("method-value").size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE }, MAKE_NV2(":scheme", "https"), MAKE_NV(":authority", d_ds->d_tlsSubjectName.c_str(), d_ds->d_tlsSubjectName.size()), MAKE_NV(":path", d_ds->d_dohPath.c_str(), d_ds->d_dohPath.size()), @@ -135,36 +206,36 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key in nva (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set */ - cerr<<"Remote size window is "< ssize_t { - cerr<<"in data provider"<(user_data); - if (userData->d_inPos >= userData->d_currentQuery.d_buffer.size()) { + size_t toCopy = 0; + if (userData->d_queryPos < userData->d_currentQuery.d_buffer.size()) { + size_t remaining = userData->d_currentQuery.d_buffer.size()- userData->d_queryPos; + toCopy = length > remaining ? remaining : length; + memcpy(buf, &userData->d_currentQuery.d_buffer.at(userData->d_queryPos), toCopy); + userData->d_queryPos += toCopy; + } + + if (userData->d_queryPos >= userData->d_currentQuery.d_buffer.size()) { *data_flags |= NGHTTP2_DATA_FLAG_EOF; - cerr<<"EOF"<d_currentQuery.d_buffer.size()- userData->d_inPos; - size_t toCopy = length > remaining ? remaining : length; - memcpy(buf, &userData->d_currentQuery.d_buffer.at(userData->d_inPos), toCopy); - userData->d_inPos += toCopy; - cerr<& sender, request.d_sender = std::move(sender); auto insertPair = d_currentStreams.insert({stream_id, std::move(request)}); if (!insertPair.second) { - cerr<<"collision!!"<& c void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param) { - cerr<<"in "<<__PRETTY_FUNCTION__<<", param is "<>(param); if (fd != conn->getHandle()) { throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle())); @@ -205,23 +277,22 @@ void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::fun do { conn->d_inPos = 0; conn->d_in.resize(conn->d_in.size() + 512); - cerr<<"trying to read "<d_in.size()<d_in.size()<d_handler->tryRead(conn->d_in, conn->d_inPos, conn->d_in.size(), true); // userData.d_handler->tryRead(userData.d_in, pos, userData.d_in.size()); - cerr<<"got a "<<(int)newState<<" state and "<d_inPos<<" bytes"<d_inPos<<" bytes"<d_in.resize(conn->d_inPos); if (newState == IOState::Done) { auto readlen = nghttp2_session_mem_recv(conn->d_session.get(), conn->d_in.data(), conn->d_inPos); - cerr<<"nghttp2_session_mem_recv returned "< 0 && static_cast(readlen) < conn->d_inPos) { cerr<<"Fatal error: "<d_session.get()); - cerr<<"nghttp2_session_send returned "<d_session.get()); } else { if (newState == IOState::NeedWrite) { @@ -232,34 +303,30 @@ void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::fun } } catch (const std::exception& e) { - cerr<<"got exception "<getConcurrentStreamsCount() > 0); } void DoHConnectionToBackend::handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param) { - cerr<<"in "<<__PRETTY_FUNCTION__<<", param is "<>(param); if (fd != conn->getHandle()) { throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle())); } IOStateGuard ioGuard(conn->d_ioState); - cerr<<"trying to write "<d_out.size()-conn->d_outPos<d_out.size()-conn->d_outPos<d_handler->tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); - cerr<<"got a "<<(int)newState<<" state, "<d_out.size()-conn->d_inPos<<" bytes remaining"<d_out.size()-conn->d_outPos<<" bytes remaining"<updateIO(IOState::NeedRead, handleWritableIOCallback); } else if (newState == IOState::Done) { + ++conn->d_queries; conn->d_out.clear(); conn->d_outPos = 0; conn->stopIO(); @@ -268,12 +335,8 @@ void DoHConnectionToBackend::handleWritableIOCallback(int fd, FDMultiplexer::fun ioGuard.release(); } catch (const std::exception& e) { - cerr<<"got exception "< ttd{boost::none}; + if (state == IOState::NeedRead) { + ttd = getBackendReadTTD(now); + } + else if (isFresh() && d_queries == 0) { + /* first write just after the non-blocking connect */ + ttd = getBackendConnectTTD(now); + } + else { + ttd = getBackendWriteTTD(now); + } + + auto shared = std::dynamic_pointer_cast(shared_from_this()); + if (shared) { + if (state == IOState::NeedRead) { + d_ioState->add(state, callback, shared, ttd); + } + else if (state == IOState::NeedWrite) { + d_ioState->add(state, callback, shared, ttd); + } + } +} + ssize_t DoHConnectionToBackend::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) { - cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); - bool bufferWasEmpty = userData->d_out.empty(); - userData->d_out.insert(userData->d_out.end(), data, data + length); + DoHConnectionToBackend* conn = reinterpret_cast(user_data); + bool bufferWasEmpty = conn->d_out.empty(); + conn->d_out.insert(conn->d_out.end(), data, data + length); if (bufferWasEmpty) { - auto state = userData->d_handler->tryWrite(userData->d_out, userData->d_outPos, userData->d_out.size()); - if (state == IOState::Done) { - userData->d_out.clear(); -#warning FIXME from now on we need to read, as we might get an answer - cerr<<"FIXME now we need to read!"<addToIOState(IOState::NeedRead); - //} + try { + auto state = conn->d_handler->tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); + if (state == IOState::Done) { + ++conn->d_queries; + conn->d_out.clear(); + conn->d_outPos = 0; + conn->addToIOState(IOState::NeedRead, handleReadableIOCallback); + } + else { + conn->updateIO(state, handleWritableIOCallback); + } } - else { -#warning write me should be addIO() instead, perhaps? - cerr<<"now we need to wait for a writable (or readable) socket"<updateIO(state, handleWritableIOCallback); + catch (const std::exception& e) { + cerr<<"Exception while trying to write (send) to HTTP backend connection: "<(user_data); - cerr<<"Frame type is "<hd.type)<hd.type)<hd.type) { case NGHTTP2_HEADERS: cerr<<"got headers"<hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { auto stream = conn->d_currentStreams.find(frame->hd.stream_id); if (stream != conn->d_currentStreams.end()) { - cerr<<"Stream "<hd.stream_id<<" is now finished"<hd.stream_id<<" is now finished"<second.d_finished = true; auto request = std::move(stream->second); conn->d_currentStreams.erase(stream->first); - conn->handleResponse(std::move(request)); + if (request.d_responseCode == 200U) { + conn->handleResponse(std::move(request)); + } else { + vinfolog("HTTP response has a non-200 status code: %d", request.d_responseCode); + struct timeval now; + gettimeofday(&now, nullptr); + + conn->handleResponseError(std::move(request), now); + } + if (conn->getConcurrentStreamsCount() == 0) { + conn->stopIO(); + } } else { - cerr<<"Stream "<hd.stream_id<<" NOT FOUND"<hd.stream_id); + conn->d_connectionDied = true; + return NGHTTP2_ERR_CALLBACK_FAILURE; } } @@ -387,270 +485,134 @@ int DoHConnectionToBackend::on_frame_recv_callback(nghttp2_session* session, con } int DoHConnectionToBackend::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data) { - cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); - cerr<<"Got data of size "<d_currentStreams.find(stream_id); if (stream == conn->d_currentStreams.end()) { - cerr<<"Unable to match the stream ID "<d_connectionDied = true; + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + if (len > std::numeric_limits::max() || (std::numeric_limits::max() - stream->second.d_buffer.size()) < len) { + vinfolog("Data frame of size %d is too large for a DNS response (we already have %d)", len, stream->second.d_buffer.size()); + conn->d_connectionDied = true; return NGHTTP2_ERR_CALLBACK_FAILURE; } + stream->second.d_buffer.insert(stream->second.d_buffer.end(), data, data + len); if (stream->second.d_finished) { - cerr<<"we now have the full response!"<(data), len)<second); conn->d_currentStreams.erase(stream->first); - conn->handleResponse(std::move(request)); - cerr<(data), len)<handleResponse(std::move(request)); + } else { + vinfolog("HTTP response has a non-200 status code: %d", request.d_responseCode); + struct timeval now; + gettimeofday(&now, nullptr); + + conn->handleResponseError(std::move(request), now); + } + if (conn->getConcurrentStreamsCount() == 0) { + conn->stopIO(); + } } else { - cerr<<"but the stream is not finished yet"<(user_data); + DoHConnectionToBackend* conn = reinterpret_cast(user_data); + + if (error_code == 0) { + return 0; + } cerr<<"Stream "<second); + conn->d_currentStreams.erase(stream->first); -int DoHConnectionToBackend::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data) { - cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); + //cerr<<"in "<<__PRETTY_FUNCTION__<<", looking for a connection to send a query of size "<d_mplexer, conn->d_ds, now); + downstream->queueQuery(request.d_sender, std::move(request.d_query)); - switch (frame->hd.type) { - case NGHTTP2_HEADERS: - if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { - /* Print response headers for the initiated request. */ - cerr<<"got header for "<hd.stream_id<<":"<(name), namelen)<(value), valuelen)<getConcurrentStreamsCount()<<" concurrent connections"<getConcurrentStreamsCount() == 0) { + //cerr<<"stopping IO"<stopIO(); + //cerr<<"our current refcnt is now "<getUsageCount()<(user_data); +int DoHConnectionToBackend::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data) { + DoHConnectionToBackend* conn = reinterpret_cast(user_data); + const std::string status(":status"); switch (frame->hd.type) { case NGHTTP2_HEADERS: if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { - cerr<<"Response headers for stream ID="<hd.stream_id<(user_data); - - return 0; -} - -#if 0 -static void doReadData(DoHConnectionToBackend& userData) -{ - do { - size_t pos = 0; - userData.d_in.resize(512); - cerr<<"trying to read "<read(userData.d_in.data(), userData.d_in.size(), timeval{2, 0}, timeval{2, 0}, true); - // userData.d_handler->tryRead(userData.d_in, pos, userData.d_in.size()); - cerr<<"got "< 0) { - auto readlen = nghttp2_session_mem_recv(userData.d_session.get(), userData.d_in.data(), pos); - cerr<<"nghttp2_session_mem_recv returned "<hd.stream_id<<":"<(name), namelen)<(value), valuelen)<d_currentStreams.find(frame->hd.stream_id); + if (stream == conn->d_currentStreams.end()) { + vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id); + conn->d_connectionDied = true; + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + try { + stream->second.d_responseCode = pdns_stou(std::string(reinterpret_cast(value), valuelen)); + } + catch (...) { + vinfolog("Error parsing the status header for stream ID %d", frame->hd.stream_id); + conn->d_connectionDied = true; + return NGHTTP2_ERR_CALLBACK_FAILURE; } - int rv = nghttp2_session_send(userData.d_session.get()); - cerr<<"nghttp2_session_send returned "< tlsCtx = getTLSContext(tlsParams); - - Socket sock(remote.sin4.sin_family, SOCK_STREAM); - // FIXME - auto fd = sock.getHandle(); - setTCPNoDelay(fd); - DoHConnectionToBackend userData; - userData.d_handler = std::make_unique(host, sock.releaseHandle(), timeval{2, 0}, tlsCtx, time(nullptr)); - userData.d_handler->connect(true, remote, timeval{2, 0}); - - /* check ALPN: -SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); -#if OPENSSL_VERSION_NUMBER >= 0x10002000L - if (alpn == NULL) { - SSL_get0_alpn_selected(ssl, &alpn, &alpnlen); - } -#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L - - if (alpn == NULL || alpnlen != 2 || memcmp("h2", alpn, 2) != 0) { - fprintf(stderr, "h2 is not negotiated\n"); - delete_http2_session_data(session_data); - return; - } - */ - - nghttp2_session_callbacks* cbs = nullptr; - if (nghttp2_session_callbacks_new(&cbs) != 0) { - cerr<<"unable to create a callback object for a new HTTP/2 session"< callbacks(cbs, nghttp2_session_callbacks_del); - cbs = nullptr; - - nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback); - nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback); - nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback); - nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback); - nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback); - nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback); - - nghttp2_session* sess = nullptr; - if (nghttp2_session_client_new(&sess, callbacks.get(), &userData) != 0) { - cerr<<"Coult not allocate a new HTTP/2 session"<(sess, nghttp2_session_del); - sess = nullptr; - - callbacks.reset(); - -#warning we should make the 100 configurable here, as we might want a lower number before receiving the one actually supported by the server -#warning we should also make the window size configurable, but 16M is a nice default - nghttp2_settings_entry iv[] = { - {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100}, - {NGHTTP2_SETTINGS_ENABLE_PUSH, 0}, - {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16*1024*1024} - }; - /* client 24 bytes magic string will be sent by nghttp2 library */ - int rv = nghttp2_submit_settings(userData.d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv)/sizeof(*iv)); - if (rv != 0) { - cerr<<"Could not submit SETTINGS: "< pw(userData.d_in, DNSName("doh.dnsdist.org."), QType::A, QClass::IN, 0); - pw.getHeader()->rd = 1; - pw.commit(); - - /* we could use nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_NAME and nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_VALUE - to avoid a copy and lowercasing as long as we take care of making sure that the data will outlive the request - and that it is already lowercased. */ - auto payloadSize = std::to_string(userData.d_in.size()); - const nghttp2_nv hdrs[] = { - MAKE_NV2(":method", "POST"), - MAKE_NV2(":scheme", "https"), - MAKE_NV(":authority", host.c_str(), host.size()), - MAKE_NV(":path", path.c_str(), path.size()), - MAKE_NV2("accept", "application/dns-message"), - MAKE_NV2("content-type", "application/dns-message"), - MAKE_NV("content-length", payloadSize.c_str(), payloadSize.size()), - MAKE_NV2("user-agent", "nghttp2-" NGHTTP2_VERSION "/dnsdist") - }; - - /* f data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key in nva (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set - */ - cerr<<"Remote size window is "< ssize_t - { - cerr<<"in data provider"<(user_data); - if (userData->d_inPos >= userData->d_in.size()) { - *data_flags |= NGHTTP2_DATA_FLAG_EOF; - cerr<<"EOF"<d_in.size()- userData->d_inPos; - size_t toCopy = length > remaining ? remaining : length; - memcpy(buf, &userData->d_in.at(userData->d_inPos), toCopy); - userData->d_inPos += toCopy; - cerr<(user_data); + conn->d_connectionDied = true; - doReadData(userData); - cerr<<"After reading data, remote size window is "< ds, std::unique_ptr& mplexer, const struct timeval& now): TCPConnectionToBackend(ds, mplexer, now) { // inherit most of the stuff from the TCPConnectionToBackend() - - /* check ALPN: -SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); -#if OPENSSL_VERSION_NUMBER >= 0x10002000L - if (alpn == NULL) { - SSL_get0_alpn_selected(ssl, &alpn, &alpnlen); - } -#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L - - if (alpn == NULL || alpnlen != 2 || memcmp("h2", alpn, 2) != 0) { - fprintf(stderr, "h2 is not negotiated\n"); - delete_http2_session_data(session_data); - return; - } - */ d_ioState = make_unique(*d_mplexer, d_handler->getDescriptor()); nghttp2_session_callbacks* cbs = nullptr; if (nghttp2_session_callbacks_new(&cbs) != 0) { - cerr<<"unable to create a callback object for a new HTTP/2 session"< callbacks(cbs, nghttp2_session_callbacks_del); @@ -661,12 +623,12 @@ SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback); nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback); nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback); - nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback); nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback); nghttp2_session* sess = nullptr; if (nghttp2_session_client_new(&sess, callbacks.get(), this) != 0) { - cerr<<"Coult not allocate a new HTTP/2 session"< getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now); - static void releaseDownstreamConnection(std::shared_ptr&& conn); - static void cleanupClosedConnections(struct timeval now); - static size_t clear(); - - static void setMaxCachedConnectionsPerDownstream(size_t max) - { - s_maxCachedConnectionsPerDownstream = max; - } - - static void setCleanupInterval(uint16_t interval) - { - s_cleanupInterval = interval; - } - -private: - static thread_local map>> t_downstreamConnections; - static size_t s_maxCachedConnectionsPerDownstream; - static time_t s_nextCleanup; - static uint16_t s_cleanupInterval; -}; - struct DoHClientCollection::DoHWorkerThread { DoHWorkerThread() @@ -778,9 +720,94 @@ bool DoHClientCollection::passCrossProtocolQueryToThread(std::unique_ptr>> DownstreamDoHConnectionsManager::t_downstreamConnections; +size_t DownstreamDoHConnectionsManager::s_maxCachedConnectionsPerDownstream{10}; +time_t DownstreamDoHConnectionsManager::s_nextCleanup{0}; +uint16_t DownstreamDoHConnectionsManager::s_cleanupInterval{60}; + +void DownstreamDoHConnectionsManager::cleanupClosedConnections(struct timeval now) +{ + struct timeval freshCutOff = now; + freshCutOff.tv_sec -= 1; + + for (auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) { + for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) { + if (!(*connIt)) { + ++connIt; + continue; + } + + /* don't bother checking freshly used connections */ + if (freshCutOff < (*connIt)->getLastDataReceivedTime()) { + ++connIt; + continue; + } + + if (isTCPSocketUsable((*connIt)->getHandle())) { + ++connIt; + } + else { + connIt = dsIt->second.erase(connIt); + } + } + + if (!dsIt->second.empty()) { + ++dsIt; + } + else { + dsIt = t_downstreamConnections.erase(dsIt); + } + } +} + std::shared_ptr DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now) { - return std::make_shared(ds, mplexer, now); + std::shared_ptr result; + struct timeval freshCutOff = now; + freshCutOff.tv_sec -= 1; + + auto backendId = ds->getID(); + + if (s_cleanupInterval > 0 && (s_nextCleanup == 0 || s_nextCleanup <= now.tv_sec)) { + s_nextCleanup = now.tv_sec + s_cleanupInterval; + //cerr<<"cleaning up"<second; + for (auto listIt = list.begin(); listIt != list.end(); ) { + auto& entry = *listIt; + if (!entry->canBeReused()) { + listIt = list.erase(listIt); + continue; + } + entry->setReused(); + /* for connections that have not been used very recently, + check whether they have been closed in the meantime */ + if (freshCutOff < entry->getLastDataReceivedTime()) { + /* used recently enough, skip the check */ + ++ds->tcpReusedConnections; + return entry; + } + + if (isTCPSocketUsable(entry->getHandle())) { + ++ds->tcpReusedConnections; + return entry; + } + + /* otherwise let's try the next one, if any */ + ++listIt; + } + } + + auto newConnection = std::make_shared(ds, mplexer, now); + t_downstreamConnections[backendId].push_back(newConnection); + return newConnection; + } } static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) @@ -812,15 +839,19 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par delete tmp; tmp = nullptr; - auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); - + try { + auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); + #warning what about the proxy protocol payload, here, do we need to remove it? we likely need to handle forward-for headers? - downstream->queueQuery(tqs, std::move(query)); + downstream->queueQuery(tqs, std::move(query)); + } + catch (...) { + tqs->notifyIOError(std::move(query.d_idstate), now); + } } catch (...) { delete tmp; tmp = nullptr; - throw; } } @@ -962,7 +993,27 @@ void DoHClientCollection::addThread() bool initDoHWorkers() { #warning FIXME: number of DoH threads - g_dohClientThreads = std::make_unique(1); + g_dohClientThreads = std::make_unique(4); g_dohClientThreads->addThread(); return true; } + +static bool select_next_proto_callback(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen) { + if (nghttp2_select_next_protocol(out, outlen, in, inlen) <= 0) { + vinfolog("The remote DoH backend did not advertise " NGHTTP2_PROTO_VERSION_ID); + return false; + } + return true; +} + +bool setupDoHClientProtocolNegotiation(std::shared_ptr& ctx) +{ + if (ctx == nullptr) { + return false; + } + /* we want to set the ALPN to h2, if only to mitigate the ALPACA attack */ + const std::vector> h2Alpns = {{'h', '2'}}; + ctx->setALPNProtos(h2Alpns); + ctx->setNextProtocolSelectCallback(select_next_proto_callback); + return true; +} diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.hh b/pdns/dnsdistdist/dnsdist-nghttp2.hh index 9713735823..0775898e81 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.hh +++ b/pdns/dnsdistdist/dnsdist-nghttp2.hh @@ -60,4 +60,7 @@ private: extern std::unique_ptr g_dohClientThreads; extern std::atomic g_dohStatesDumpRequested; +class TLSCtx; + bool initDoHWorkers(); +bool setupDoHClientProtocolNegotiation(std::shared_ptr& ctx); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index 2201f2df4c..9370a8a007 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -73,7 +73,7 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptrd_currentQuery.d_proxyProtocolPayloadAdded) { conn->d_proxyProtocolPayloadSent = true; } - conn->incQueries(); + ++conn->d_queries; conn->d_currentPos = 0; DEBUGLOG("adding a pending response for ID "<d_currentQuery.d_idstate.origID)<<" and QNAME "<d_currentQuery.d_idstate.qname); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh index 9301ad70d8..aafbd3dd90 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh @@ -46,11 +46,6 @@ public: return d_fresh; } - void incQueries() - { - ++d_queries; - } - void setReused() { d_fresh = false; @@ -86,7 +81,7 @@ public: } /* whether a connection can be reused for a different client */ - bool canBeReused() const + virtual bool canBeReused() const { if (d_connectionDied) { return false; @@ -126,7 +121,7 @@ public: virtual std::string toString() const { ostringstream o; - o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<& state, bo std::string toString() const { ostringstream o; - o << "Incoming TCP connection from "<getState()) : "empty")<<", queries count is "<getState() : "empty")<<", queries count is "<ids.cs = &cs; setIDStateFromDNSQuestion(du->ids, dq, std::move(qname)); - if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { + if (du->downstream->passCrossProtocolQuery(std::move(cpq))) { return 0; } else { diff --git a/pdns/dnsdistdist/tcpiohandler-mplexer.hh b/pdns/dnsdistdist/tcpiohandler-mplexer.hh index c8d98b8b98..2fbe4b5147 100644 --- a/pdns/dnsdistdist/tcpiohandler-mplexer.hh +++ b/pdns/dnsdistdist/tcpiohandler-mplexer.hh @@ -14,11 +14,11 @@ class IOStateHandler { public: - IOStateHandler(FDMultiplexer& mplexer, const int fd): d_mplexer(mplexer), d_fd(fd), d_currentState(IOState::Done) + IOStateHandler(FDMultiplexer& mplexer, const int fd): d_mplexer(mplexer), d_fd(fd) { } - IOStateHandler(FDMultiplexer& mplexer): d_mplexer(mplexer), d_fd(-1), d_currentState(IOState::Done) + IOStateHandler(FDMultiplexer& mplexer): d_mplexer(mplexer), d_fd(-1) { } @@ -36,9 +36,14 @@ public: } } - IOState getState() const + bool isWaitingForRead() const { - return d_currentState; + return d_isWaitingForRead; + } + + bool isWaitingForWrite() const + { + return d_isWaitingForWrite; } void setSocket(int fd) @@ -54,22 +59,66 @@ public: update(IOState::Done); } + std::string getState() const + { + std::string result("--"); + result.reserve(2); + if (isWaitingForRead()) { + result.at(0) = 'R'; + } + if (isWaitingForWrite()) { + result.at(1) = 'W'; + } + return result; + } + + void add(IOState iostate, FDMultiplexer::callbackfunc_t callback, FDMultiplexer::funcparam_t callbackData, boost::optional ttd) + { + DEBUGLOG("in "<<__PRETTY_FUNCTION__<<" for fd "< ttd = boost::none) { - DEBUGLOG("in "<<__PRETTY_FUNCTION__<<" for fd "< getNextProtocol() const override + { + return std::vector(); + } + LibsslTLSVersion getTLSVersion() const override { return LibsslTLSVersion::TLS13; diff --git a/pdns/libssl.cc b/pdns/libssl.cc index b667d27d01..ccc287ed6d 100644 --- a/pdns/libssl.cc +++ b/pdns/libssl.cc @@ -793,6 +793,40 @@ std::unique_ptr libssl_set_key_log_file(std::unique_ptr& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg) +{ +#ifdef HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB + SSL_CTX_set_next_proto_select_cb(ctx.get(), cb, arg); +#endif +} + +void libssl_set_alpn_select_callback(std::unique_ptr& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg) +{ +#ifdef HAVE_SSL_CTX_SET_ALPN_SELECT_CB + SSL_CTX_set_alpn_select_cb(ctx.get(), cb, arg); +#endif +} + +bool libssl_set_alpn_protos(std::unique_ptr& ctx, const std::vector>& protos) +{ +#ifdef HAVE_SSL_CTX_SET_ALPN_PROTOS + std::vector wire; + for (const auto& proto : protos) { + if (proto.size() > std::numeric_limits::max()) { + throw std::runtime_error("Invalid ALPN value"); + } + uint8_t length = proto.size(); + wire.push_back(length); + wire.insert(wire.end(), proto.begin(), proto.end()); + } + return SSL_CTX_set_alpn_protos(ctx.get(), wire.data(), wire.size()) == 0; +#else + return false; +#endif +} + + std::string libssl_get_error_string() { BIO *mem = BIO_new(BIO_s_mem()); diff --git a/pdns/libssl.hh b/pdns/libssl.hh index b090afa7b1..2af0f4ef89 100644 --- a/pdns/libssl.hh +++ b/pdns/libssl.hh @@ -126,6 +126,13 @@ std::unique_ptr libssl_init_server_context(const TLS std::unique_ptr libssl_set_key_log_file(std::unique_ptr& ctx, const std::string& logFile); +/* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */ +void libssl_set_npn_select_callback(std::unique_ptr& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg); +/* called in a server context, to select an ALPN value advertised by the client if any */ +void libssl_set_alpn_select_callback(std::unique_ptr& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg); +/* set the supported ALPN protos in client context */ +bool libssl_set_alpn_protos(std::unique_ptr& ctx, const std::vector>& protos); + std::string libssl_get_error_string(); #endif /* HAVE_LIBSSL */ diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 4eb80627ab..e957d66bbe 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -151,7 +151,12 @@ public: return IOState::NeedWrite; } else if (error == SSL_ERROR_SYSCALL) { - throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno))); + if (errno == 0) { + throw std::runtime_error("TLS connection closed by remote end"); + } + else { + throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno))); + } } else if (error == SSL_ERROR_ZERO_RETURN) { throw std::runtime_error("TLS connection closed by remote end"); @@ -401,6 +406,29 @@ public: return std::string(); } + std::vector getNextProtocol() const override + { + std::vector result; + if (!d_conn) { + return result; + } + + const unsigned char* alpn = nullptr; + unsigned int alpnLen = 0; +#ifdef HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED + SSL_get0_next_proto_negotiated(d_conn.get(), &alpn, &alpnLen); +#endif +#ifdef HAVE_SSL_GET0_ALPN_SELECTED + if (alpn == nullptr) { + SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen); + } +#endif + if (alpn != nullptr && alpnLen > 0) { + result.insert(result.end(), alpn, alpn + alpnLen); + } + return result; + } + LibsslTLSVersion getTLSVersion() const override { auto proto = SSL_version(d_conn.get()); @@ -668,9 +696,74 @@ public: return "openssl"; } + bool setALPNProtos(const std::vector>& protos) override + { + if (d_feContext && d_feContext->d_tlsCtx) { + d_alpnProtos = protos; + libssl_set_alpn_select_callback(d_feContext->d_tlsCtx, alpnServerSelectCallback, this); + return true; + } + if (d_tlsCtx) { + return libssl_set_alpn_protos(d_tlsCtx, protos); + } + return false; + } + + bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override + { + d_nextProtocolSelectCallback = cb; + libssl_set_npn_select_callback(d_tlsCtx, npnSelectCallback, this); + return true; + } + private: + /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */ + static int npnSelectCallback(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg) + { + if (!arg) { + return SSL_TLSEXT_ERR_ALERT_WARNING; + } + OpenSSLTLSIOCtx* obj = reinterpret_cast(arg); + if (obj->d_nextProtocolSelectCallback) { + return (*obj->d_nextProtocolSelectCallback)(out, outlen, in, inlen) ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_ALERT_WARNING; + } + + return SSL_TLSEXT_ERR_OK; + } + + static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg) + { + if (!arg) { + return SSL_TLSEXT_ERR_ALERT_WARNING; + } + OpenSSLTLSIOCtx* obj = reinterpret_cast(arg); + + size_t pos = 0; + while (pos < inlen) { + size_t protoLen = in[pos]; + pos++; + if (protoLen > (inlen - pos)) { + /* something is very wrong */ + return SSL_TLSEXT_ERR_ALERT_WARNING; + } + + for (const auto& tentative : obj->d_alpnProtos) { + if (tentative.size() == protoLen && memcmp(in + pos, tentative.data(), tentative.size()) == 0) { + *out = in + pos; + *outlen = protoLen; + return SSL_TLSEXT_ERR_OK; + } + } + pos += protoLen; + } + + return SSL_TLSEXT_ERR_NOACK; + } + + std::vector> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent std::shared_ptr d_feContext; - std::unique_ptr d_tlsCtx; // client context + std::unique_ptr d_tlsCtx; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx + bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr}; }; #endif /* HAVE_LIBSSL */ @@ -1226,6 +1319,20 @@ public: return std::string(); } + std::vector getNextProtocol() const override + { + std::vector result; + if (!d_conn) { + return result; + } + gnutls_datum_t next; + if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) { + return result; + } + result.insert(result.end(), next.data, next.data + next.size); + return result; + } + LibsslTLSVersion getTLSVersion() const override { auto proto = gnutls_protocol_get_version(d_conn.get()); @@ -1285,6 +1392,19 @@ public: } } + bool setALPNProtos(const std::vector>& protos) + { + std::vector values; + values.reserve(protos.size()); + for (const auto& proto : protos) { + gnutls_datum_t value; + value.data = const_cast(proto.data()); + value.size = proto.size(); + values.push_back(value); + } + return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), GNUTLS_ALPN_MANDATORY); + } + private: std::shared_ptr d_ticketsKey; std::unique_ptr d_conn; @@ -1406,12 +1526,20 @@ public: ticketsKey = *(d_ticketsKey.read_lock()); } - return std::make_unique(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets); + auto connection = std::make_unique(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets); + if (!d_protos.empty()) { + connection->setALPNProtos(d_protos); + } + return connection; } std::unique_ptr getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override { - return std::make_unique(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts); + auto connection = std::make_unique(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts); + if (!d_protos.empty()) { + connection->setALPNProtos(d_protos); + } + return connection; } void rotateTicketsKey(time_t now) override @@ -1457,8 +1585,19 @@ public: return "gnutls"; } + bool setALPNProtos(const std::vector>& protos) override + { +#ifdef HAVE_GNUTLS_ALPN_SET_PROTOCOLS + d_protos = protos; + return true; +#else + return false; +#endif + } + private: std::unique_ptr d_creds; + std::vector> d_protos; gnutls_priority_t d_priorityCache{nullptr}; SharedLockGuarded> d_ticketsKey{nullptr}; bool d_enableTickets{true}; @@ -1469,6 +1608,17 @@ private: #endif /* HAVE_DNS_OVER_TLS */ +bool setupDoTProtocolNegotiation(std::shared_ptr& ctx) +{ + if (ctx == nullptr) { + return false; + } + /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */ + const std::vector> dotAlpns = {{'d', 'o', 't'}}; + ctx->setALPNProtos(dotAlpns); + return true; +} + bool TLSFrontend::setupTLS() { #ifdef HAVE_DNS_OVER_TLS @@ -1478,6 +1628,7 @@ bool TLSFrontend::setupTLS() #ifdef HAVE_GNUTLS if (d_provider == "gnutls") { newCtx = std::make_shared(*this); + setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); return true; } @@ -1485,6 +1636,7 @@ bool TLSFrontend::setupTLS() #ifdef HAVE_LIBSSL if (d_provider == "openssl") { newCtx = std::make_shared(*this); + setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); return true; } @@ -1498,6 +1650,7 @@ bool TLSFrontend::setupTLS() #endif /* HAVE_GNUTLS */ #endif /* HAVE_LIBSSL */ + setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); #endif /* HAVE_DNS_OVER_TLS */ return true; diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index e948b130ae..fcde2c62ae 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -33,6 +33,7 @@ public: virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0; virtual bool hasBufferedData() const = 0; virtual std::string getServerNameIndication() const = 0; + virtual std::vector getNextProtocol() const = 0; virtual LibsslTLSVersion getTLSVersion() const = 0; virtual bool hasSessionBeenResumed() const = 0; virtual std::unique_ptr getSession() = 0; @@ -111,6 +112,18 @@ public: virtual size_t getTicketsKeysCount() = 0; virtual std::string getName() const = 0; + /* set the advertised ALPN protocols, in client or server context */ + virtual bool setALPNProtos(const std::vector>& protos) + { + return false; + } + + /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */ + virtual bool setNextProtocolSelectCallback(bool(*)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) + { + return false; + } + protected: std::atomic_flag d_rotatingTicketsKey; std::atomic d_ticketsKeyNextRotation{0}; @@ -465,6 +478,14 @@ public: return std::string(); } + std::vector getNextProtocol() const + { + if (d_conn) { + return d_conn->getNextProtocol(); + } + return std::vector(); + } + LibsslTLSVersion getTLSVersion() const { if (d_conn) { @@ -528,3 +549,4 @@ struct TLSContextParameters }; std::shared_ptr getTLSContext(const TLSContextParameters& params); +bool setupDoTProtocolNegotiation(std::shared_ptr& ctx);