From: Remi Gacogne Date: Mon, 19 Jul 2021 16:06:53 +0000 (+0200) Subject: Working DoH between dnsdist and the backend! X-Git-Tag: dnsdist-1.7.0-alpha1~23^2~34 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9eb5394aa3661dca8669b9ca6ffea0b3c978e415;p=thirdparty%2Fpdns.git Working DoH between dnsdist and the backend! --- diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 69663e5191..b1f8465b1b 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -530,6 +530,10 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret->d_tlsCtx = getTLSContext(tlsParams); } + if (vars.count("dohPath")) { + ret->d_dohPath = boost::get(vars.at("dohPath")); + } + /* this needs to be done _AFTER_ the order has been set, since the server are kept ordered inside the pool */ auto localPools = g_pools.getCopy(); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index de23e7b933..490a3d52c8 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -51,6 +51,7 @@ #include "dnsdist-ecs.hh" #include "dnsdist-healthchecks.hh" #include "dnsdist-lua.hh" +#include "dnsdist-nghttp2.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" #include "dnsdist-secpoll.hh" @@ -1505,12 +1506,8 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } auto cpq = std::make_unique(std::move(query), std::move(ids), ss); - if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { - return ; - } - else { - return; - } + ss->passCrossProtocolQuery(std::move(cpq)); + return; } unsigned int idOffset = (ss->idOffset++) % ss->idStates.size(); @@ -2171,8 +2168,6 @@ static void sighandler(int sig) } #endif -#include "dnsdist-nghttp2.hh" - int main(int argc, char** argv) { try { @@ -2553,6 +2548,8 @@ int main(int argc, char** argv) g_tcpclientthreads = std::make_unique(*g_maxTCPClientThreads); + initDoHWorkers(); + for (auto& t : todo) { t(); } @@ -2640,8 +2637,6 @@ int main(int argc, char** argv) secpollthread.detach(); } - sendHTTP2Query(); - if(g_cmdLine.beSupervised) { #ifdef HAVE_SYSTEMD sd_notify(0, "READY=1"); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 85fa96ea2e..4ba0b80876 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -649,6 +649,8 @@ struct ClientState } }; +struct CrossProtocolQuery; + struct DownstreamState { typedef std::function(const DNSName&, uint16_t, uint16_t, dnsheader*)> checkfunc_t; @@ -662,6 +664,7 @@ struct DownstreamState std::vector sockets; const std::string sourceItfName; std::string d_tlsSubjectName; + std::string d_dohPath; std::mutex connectLock; LockGuarded> mplexer{nullptr}; std::shared_ptr d_tlsCtx{nullptr}; @@ -823,6 +826,8 @@ struct DownstreamState return d_tcpOnly || d_tlsCtx != nullptr; } + bool passCrossProtocolQuery(std::unique_ptr&& cpq); + private: std::string name; std::string nameWithAddr; diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 5928ca9e8b..03ae3e1110 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -242,6 +242,7 @@ testrunner_SOURCES = \ dnsdist-lua-ffi-interface.h dnsdist-lua-ffi-interface.inc \ dnsdist-lua-ffi.cc dnsdist-lua-ffi.hh \ dnsdist-lua-vars.cc \ + dnsdist-nghttp2.cc dnsdist-nghttp2.hh \ dnsdist-protocols.cc dnsdist-protocols.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ dnsdist-rings.cc dnsdist-rings.hh \ @@ -321,6 +322,7 @@ testrunner_LDFLAGS = \ $(AM_LDFLAGS) \ $(PROGRAM_LDFLAGS) \ $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) \ + -lnghttp2 \ -pthread testrunner_LDADD = \ diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index d0ffe2eecb..ceef5574b4 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -21,8 +21,21 @@ */ #include "dnsdist.hh" +#include "dnsdist-nghttp2.hh" +#include "dnsdist-tcp.hh" #include "dolog.hh" + +bool DownstreamState::passCrossProtocolQuery(std::unique_ptr&& cpq) +{ + if (d_dohPath.empty()) { + return g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq)); + } + else { + return g_dohClientThreads && g_dohClientThreads->passCrossProtocolQueryToThread(std::move(cpq)); + } +} + bool DownstreamState::reconnect() { std::unique_lock tl(connectLock, std::try_to_lock); diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 9e95676b62..592a21f6d4 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -1,40 +1,347 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ #include +#include "dnsdist-nghttp2.hh" +#include "dnsdist-tcp.hh" +#include "dnsdist-tcp-downstream.hh" + +#include "dolog.hh" #include "iputils.hh" #include "libssl.hh" #include "noinitvector.hh" #include "tcpiohandler.hh" +#include "threadname.hh" #include "sstuff.hh" #warning remove me #include "dnswriter.hh" -struct MyUserData +std::atomic g_dohStatesDumpRequested{0}; +std::unique_ptr g_dohClientThreads{nullptr}; + +class DoHConnectionToBackend: public TCPConnectionToBackend +{ +public: + DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now); + + void handleTimeout(const struct timeval& now, bool write) override + { +#warning FIXME: we should notify the owners of pending queries / responses + } + + void queueQuery(std::shared_ptr& sender, TCPQuery&& query) override; + + std::string toString() const override + { + ostringstream o; + //o << "DoH connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket"); + return o.str(); + } + +private: + static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data); + static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data); + static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data); + static int on_stream_close_callback(nghttp2_session* session, int32_t stream_id, uint32_t error_code, void* user_data); + static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data); + static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data); + static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data); + static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param); + static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param); + static void handleIO(std::shared_ptr& conn, const struct timeval& now); + + class PendingRequest + { + public: + std::shared_ptr d_sender{nullptr}; + TCPQuery d_query; + PacketBuffer d_buffer; + bool d_finished{false}; + }; + void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback); + void stopIO(); + void handleResponse(PendingRequest&& request); + + //std::deque d_pendingQueries; + + std::unique_ptr d_session{nullptr, nghttp2_session_del}; + std::unordered_map d_currentStreams; + PacketBuffer d_out; + PacketBuffer d_in; + size_t d_outPos{0}; + size_t d_inPos{0}; +}; + + +void DoHConnectionToBackend::handleResponse(PendingRequest&& request) +{ + cerr<<"handle response!"<handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this())); +} + +#define MAKE_NV(NAME, VALUE, VALUELEN) \ + { \ + (uint8_t *)NAME, (uint8_t *)VALUE, sizeof(NAME) - 1, VALUELEN, \ + NGHTTP2_NV_FLAG_NONE \ + } + +#define MAKE_NV2(NAME, VALUE) \ + { \ + (uint8_t *)NAME, (uint8_t *)VALUE, sizeof(NAME) - 1, sizeof(VALUE) - 1, \ + NGHTTP2_NV_FLAG_NONE \ + } + +void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, TCPQuery&& query) +{ + /* we could use nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_NAME and nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_VALUE + to avoid a copy and lowercasing as long as we take care of making sure that the data will outlive the request + and that it is already lowercased. */ + auto payloadSize = std::to_string(query.d_buffer.size()); + d_currentQuery = std::move(query); + const nghttp2_nv hdrs[] = { + MAKE_NV2(":method", "POST"), + MAKE_NV2(":scheme", "https"), + MAKE_NV(":authority", d_ds->d_tlsSubjectName.c_str(), d_ds->d_tlsSubjectName.size()), + MAKE_NV(":path", d_ds->d_dohPath.c_str(), d_ds->d_dohPath.size()), + MAKE_NV2("accept", "application/dns-message"), + MAKE_NV2("content-type", "application/dns-message"), + MAKE_NV("content-length", payloadSize.c_str(), payloadSize.size()), + MAKE_NV2("user-agent", "nghttp2-" NGHTTP2_VERSION "/dnsdist") + }; + + /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key in nva (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set + */ + cerr<<"Remote size window is "< ssize_t + { + cerr<<"in data provider"<(user_data); + if (userData->d_inPos >= userData->d_currentQuery.d_buffer.size()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + cerr<<"EOF"<d_currentQuery.d_buffer.size()- userData->d_inPos; + size_t toCopy = length > remaining ? remaining : length; + memcpy(buf, &userData->d_currentQuery.d_buffer.at(userData->d_inPos), toCopy); + userData->d_inPos += toCopy; + cerr< session{nullptr, nghttp2_session_del}; - std::unique_ptr handler; - PacketBuffer out; - PacketBuffer in; - size_t outPos{0}; - size_t inPos{0}; +public: + DoHClientThreadData(): mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) + { + } + + std::unique_ptr mplexer{nullptr}; }; -static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) { +void DoHConnectionToBackend::handleIO(std::shared_ptr& conn, const struct timeval& now) +{ +} + +void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + cerr<<"in "<<__PRETTY_FUNCTION__<<", param is "<>(param); + if (fd != conn->getHandle()) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle())); + } + + IOStateGuard ioGuard(conn->d_ioState); + do { + conn->d_inPos = 0; + conn->d_in.resize(conn->d_in.size() + 512); + cerr<<"trying to read "<d_in.size()<d_handler->tryRead(conn->d_in, conn->d_inPos, conn->d_in.size(), true); + // userData.d_handler->tryRead(userData.d_in, pos, userData.d_in.size()); + cerr<<"got a "<<(int)newState<<" state and "<d_inPos<<" bytes"<d_in.resize(conn->d_inPos); + if (newState == IOState::Done) { + auto readlen = nghttp2_session_mem_recv(conn->d_session.get(), conn->d_in.data(), conn->d_inPos); + cerr<<"nghttp2_session_mem_recv returned "< 0 && static_cast(readlen) < conn->d_inPos) { + cerr<<"Fatal error: "<d_session.get()); + cerr<<"nghttp2_session_send returned "<updateIO(IOState::NeedWrite, handleReadableIOCallback); + } + ioGuard.release(); + break; + } + } + catch (const std::exception& e) { + cerr<<"got exception "<>(param); + if (fd != conn->getHandle()) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle())); + } + IOStateGuard ioGuard(conn->d_ioState); + + cerr<<"trying to write "<d_out.size()-conn->d_outPos<d_handler->tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); + cerr<<"got a "<<(int)newState<<" state, "<d_out.size()-conn->d_inPos<<" bytes remaining"<updateIO(IOState::NeedRead, handleWritableIOCallback); + } + else if (newState == IOState::Done) { + conn->d_out.clear(); + conn->d_outPos = 0; + conn->stopIO(); + conn->updateIO(IOState::NeedRead, handleReadableIOCallback); + } + ioGuard.release(); + } + catch (const std::exception& e) { + cerr<<"got exception "<reset(); +} + +void DoHConnectionToBackend::updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback) +{ + struct timeval now; + gettimeofday(&now, nullptr); + boost::optional ttd{boost::none}; + if (newState == IOState::NeedRead) { + ttd = getBackendReadTTD(now); + } + else if (isFresh() && d_queries == 0) { + /* first write just after the non-blocking connect */ + ttd = getBackendConnectTTD(now); + } + else { + ttd = getBackendWriteTTD(now); + } + + auto shared = std::dynamic_pointer_cast(shared_from_this()); + if (shared) { + if (newState == IOState::NeedRead) { + d_ioState->update(newState, callback, shared, ttd); + } + else if (newState == IOState::NeedWrite) { + d_ioState->update(newState, callback, shared, ttd); + } + } +} + +ssize_t DoHConnectionToBackend::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) { cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); - userData->out.insert(userData->out.end(), data, data + length); - userData->handler->write(userData->out.data() + userData->outPos, userData->out.size() - userData->outPos, timeval{2, 0}); - userData->out.clear(); + DoHConnectionToBackend* userData = reinterpret_cast(user_data); + bool bufferWasEmpty = userData->d_out.empty(); + userData->d_out.insert(userData->d_out.end(), data, data + length); + + if (bufferWasEmpty) { + auto state = userData->d_handler->tryWrite(userData->d_out, userData->d_outPos, userData->d_out.size()); + if (state == IOState::Done) { + userData->d_out.clear(); +#warning FIXME from now on we need to read, as we might get an answer + cerr<<"FIXME now we need to read!"<addToIOState(IOState::NeedRead); + //} + } + else { +#warning write me should be addIO() instead, perhaps? + cerr<<"now we need to wait for a writable (or readable) socket"<updateIO(state, handleWritableIOCallback); + } + } + return length; } -static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) { +int DoHConnectionToBackend::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) { cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); + DoHConnectionToBackend* conn = reinterpret_cast(user_data); + cerr<<"Frame type is "<hd.type)<hd.type) { case NGHTTP2_HEADERS: + cerr<<"got headers"<headers.cat == NGHTTP2_HCAT_RESPONSE) { cerr<<"All headers received"<settings.iv[idx].settings_id<<" "<settings.iv[idx].value<hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + auto stream = conn->d_currentStreams.find(frame->hd.stream_id); + if (stream != conn->d_currentStreams.end()) { + cerr<<"Stream "<hd.stream_id<<" is now finished"<second.d_finished = true; + + auto request = std::move(stream->second); + conn->d_currentStreams.erase(stream->first); + conn->handleResponse(std::move(request)); + } + else { + cerr<<"Stream "<hd.stream_id<<" NOT FOUND"<(user_data); - cerr<<"Got data of size "<(data), len)<(user_data); + cerr<<"Got data of size "<d_currentStreams.find(stream_id); + if (stream == conn->d_currentStreams.end()) { + cerr<<"Unable to match the stream ID "<second.d_buffer.insert(stream->second.d_buffer.end(), data, data + len); + if (stream->second.d_finished) { + cerr<<"we now have the full response!"<second); + conn->d_currentStreams.erase(stream->first); + conn->handleResponse(std::move(request)); + cerr<(data), len)<(user_data); + //DoHConnectionToBackend* userData = reinterpret_cast(user_data); cerr<<"Stream "<(user_data); + //DoHConnectionToBackend* userData = reinterpret_cast(user_data); switch (frame->hd.type) { case NGHTTP2_HEADERS: @@ -92,9 +440,9 @@ static int on_header_callback(nghttp2_session* session, const nghttp2_frame* fra return 0; } -static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) { +int DoHConnectionToBackend::on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) { cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); + //DoHConnectionToBackend* userData = reinterpret_cast(user_data); switch (frame->hd.type) { case NGHTTP2_HEADERS: @@ -106,25 +454,34 @@ static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_fra return 0; } -static void doReadData(MyUserData& userData) +int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data) { + cerr<<"in "<<__PRETTY_FUNCTION__<(user_data); + + return 0; +} + +#if 0 +static void doReadData(DoHConnectionToBackend& userData) { do { size_t pos = 0; - userData.in.resize(512); - cerr<<"trying to read "<read(userData.in.data(), userData.in.size(), timeval{2, 0}, timeval{2, 0}, true); - // userData.handler->tryRead(userData.in, pos, userData.in.size()); + pos = userData.d_handler->read(userData.d_in.data(), userData.d_in.size(), timeval{2, 0}, timeval{2, 0}, true); + // userData.d_handler->tryRead(userData.d_in, pos, userData.d_in.size()); cerr<<"got "< 0) { - auto readlen = nghttp2_session_mem_recv(userData.session.get(), userData.in.data(), pos); + auto readlen = nghttp2_session_mem_recv(userData.d_session.get(), userData.d_in.data(), pos); cerr<<"nghttp2_session_mem_recv returned "<(host, sock.releaseHandle(), timeval{2, 0}, tlsCtx, time(nullptr)); - userData.handler->connect(true, remote, timeval{2, 0}); + DoHConnectionToBackend userData; + userData.d_handler = std::make_unique(host, sock.releaseHandle(), timeval{2, 0}, tlsCtx, time(nullptr)); + userData.d_handler->connect(true, remote, timeval{2, 0}); /* check ALPN: SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); @@ -203,7 +549,7 @@ SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); return; } - userData.session = std::unique_ptr(sess, nghttp2_session_del); + userData.d_session = std::unique_ptr(sess, nghttp2_session_del); sess = nullptr; callbacks.reset(); @@ -216,20 +562,20 @@ SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16*1024*1024} }; /* client 24 bytes magic string will be sent by nghttp2 library */ - int rv = nghttp2_submit_settings(userData.session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv)/sizeof(*iv)); + int rv = nghttp2_submit_settings(userData.d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv)/sizeof(*iv)); if (rv != 0) { cerr<<"Could not submit SETTINGS: "< pw(userData.in, DNSName("doh.dnsdist.org."), QType::A, QClass::IN, 0); + GenericDNSPacketWriter pw(userData.d_in, DNSName("doh.dnsdist.org."), QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; pw.commit(); /* we could use nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_NAME and nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_VALUE to avoid a copy and lowercasing as long as we take care of making sure that the data will outlive the request and that it is already lowercased. */ - auto payloadSize = std::to_string(userData.in.size()); + auto payloadSize = std::to_string(userData.d_in.size()); const nghttp2_nv hdrs[] = { MAKE_NV2(":method", "POST"), MAKE_NV2(":scheme", "https"), @@ -243,40 +589,380 @@ SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); /* f data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key in nva (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set */ - cerr<<"Remote size window is "< ssize_t { cerr<<"in data provider"<(user_data); - if (userData->inPos >= userData->in.size()) { + auto userData = reinterpret_cast(user_data); + if (userData->d_inPos >= userData->d_in.size()) { *data_flags |= NGHTTP2_DATA_FLAG_EOF; cerr<<"EOF"<in.size()- userData->inPos; + size_t remaining = userData->d_in.size()- userData->d_inPos; size_t toCopy = length > remaining ? remaining : length; - memcpy(buf, &userData->in.at(userData->inPos), toCopy); - userData->inPos += toCopy; + memcpy(buf, &userData->d_in.at(userData->d_inPos), toCopy); + userData->d_inPos += toCopy; cerr< ds, std::unique_ptr& mplexer, const struct timeval& now): TCPConnectionToBackend(ds, mplexer, now) +{ + // inherit most of the stuff from the TCPConnectionToBackend() + + /* check ALPN: +SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen); +#if OPENSSL_VERSION_NUMBER >= 0x10002000L + if (alpn == NULL) { + SSL_get0_alpn_selected(ssl, &alpn, &alpnlen); + } +#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L + + if (alpn == NULL || alpnlen != 2 || memcmp("h2", alpn, 2) != 0) { + fprintf(stderr, "h2 is not negotiated\n"); + delete_http2_session_data(session_data); + return; + } + */ + d_ioState = make_unique(*d_mplexer, d_handler->getDescriptor()); + + nghttp2_session_callbacks* cbs = nullptr; + if (nghttp2_session_callbacks_new(&cbs) != 0) { + cerr<<"unable to create a callback object for a new HTTP/2 session"< callbacks(cbs, nghttp2_session_callbacks_del); + cbs = nullptr; + + nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback); + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback); + nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback); + nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback); + nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback); + + nghttp2_session* sess = nullptr; + if (nghttp2_session_client_new(&sess, callbacks.get(), this) != 0) { + cerr<<"Coult not allocate a new HTTP/2 session"<(sess, nghttp2_session_del); + sess = nullptr; + + callbacks.reset(); + +#warning we should make the 100 configurable here, as we might want a lower number before receiving the one actually supported by the server +#warning we should also make the window size configurable, but 16M is a nice default + nghttp2_settings_entry iv[] = { + {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100}, + {NGHTTP2_SETTINGS_ENABLE_PUSH, 0}, + {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16*1024*1024} + }; + /* client 24 bytes magic string will be sent by nghttp2 library */ + int rv = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv)/sizeof(*iv)); + if (rv != 0) { + cerr<<"Could not submit SETTINGS: "< getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now); + static void releaseDownstreamConnection(std::shared_ptr&& conn); + static void cleanupClosedConnections(struct timeval now); + static size_t clear(); + + static void setMaxCachedConnectionsPerDownstream(size_t max) + { + s_maxCachedConnectionsPerDownstream = max; + } + + static void setCleanupInterval(uint16_t interval) + { + s_cleanupInterval = interval; + } + +private: + static thread_local map>> t_downstreamConnections; + static size_t s_maxCachedConnectionsPerDownstream; + static time_t s_nextCleanup; + static uint16_t s_cleanupInterval; +}; + +struct DoHClientCollection::DoHWorkerThread +{ + DoHWorkerThread() + { + } + + DoHWorkerThread(int crossProtocolPipe): d_crossProtocolQueryPipe(crossProtocolPipe) + { + } + + DoHWorkerThread(DoHWorkerThread&& rhs): d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe) + { + rhs.d_crossProtocolQueryPipe = -1; + } + + DoHWorkerThread& operator=(DoHWorkerThread&& rhs) + { + if (d_crossProtocolQueryPipe != -1) { + close(d_crossProtocolQueryPipe); + } + + d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe; + rhs.d_crossProtocolQueryPipe = -1; + + return *this; + } + + DoHWorkerThread(const DoHWorkerThread& rhs) = delete; + DoHWorkerThread& operator=(const DoHWorkerThread&) = delete; + + ~DoHWorkerThread() + { + if (d_crossProtocolQueryPipe != -1) { + close(d_crossProtocolQueryPipe); + } + } + + int d_crossProtocolQueryPipe{-1}; +}; + +DoHClientCollection::DoHClientCollection(size_t maxThreads): d_clientThreads(maxThreads), d_maxThreads(maxThreads) +{ +} + +bool DoHClientCollection::passCrossProtocolQueryToThread(std::unique_ptr&& cpq) +{ + if (d_numberOfThreads == 0) { + throw std::runtime_error("No DoH worker thread yet"); + } + + uint64_t pos = d_pos++; + auto pipe = d_clientThreads.at(pos % d_numberOfThreads).d_crossProtocolQueryPipe; + auto tmp = cpq.release(); + + if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { + delete tmp; + tmp = nullptr; + return false; + } + + return true; +} + +std::shared_ptr DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now) +{ + return std::make_shared(ds, mplexer, now); +} + +static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) +{ + auto threadData = boost::any_cast(param); + CrossProtocolQuery* tmp{nullptr}; + + ssize_t got = read(pipefd, &tmp, sizeof(tmp)); + if (got == 0) { + throw std::runtime_error("EOF while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + else if (got == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + throw std::runtime_error("Error while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + } + else if (got != sizeof(tmp)) { + throw std::runtime_error("Partial read while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + + try { + struct timeval now; + gettimeofday(&now, nullptr); + + std::shared_ptr tqs = tmp->getTCPQuerySender(); + auto query = std::move(tmp->query); + auto downstreamServer = std::move(tmp->downstream); + delete tmp; + tmp = nullptr; + + auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now); + +#warning what about the proxy protocol payload, here, do we need to remove it? we likely need to handle forward-for headers? + downstream->queueQuery(tqs, std::move(query)); + } + catch (...) { + delete tmp; + tmp = nullptr; + throw; + } +} + +static void dohClientThread(int crossProtocolPipeFD) +{ + setThreadName("dnsdist/dohClie"); + + DoHClientThreadData data; + + data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data); + + struct timeval now; + gettimeofday(&now, nullptr); + time_t lastTimeoutScan = now.tv_sec; + + for (;;) { + data.mplexer->run(&now); + + if (now.tv_sec > lastTimeoutScan) { + lastTimeoutScan = now.tv_sec; + auto expiredReadConns = data.mplexer->getTimeouts(now, false); + for (const auto& cbData : expiredReadConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(cbData.second); + vinfolog("Timeout (read) from remote DoH backend %s", conn->getBackendName()); + conn->handleTimeout(now, false); + } + } + + auto expiredWriteConns = data.mplexer->getTimeouts(now, true); + for (const auto& cbData : expiredWriteConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(cbData.second); + vinfolog("Timeout (write) from remote DoH backend %s", conn->getBackendName()); + conn->handleTimeout(now, true); + } + } + + if (g_dohStatesDumpRequested > 0) { + /* just to keep things clean in the output, debug only */ + static std::mutex s_lock; + std::lock_guard lck(s_lock); + if (g_dohStatesDumpRequested > 0) { + /* no race here, we took the lock so it can only be increased in the meantime */ + --g_dohStatesDumpRequested; + errlog("Dumping the DoH client states, as requested:"); + data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd) + { + struct timeval lnow; + gettimeofday(&lnow, nullptr); + if (ttd.tv_sec > 0) { + errlog("- Descriptor %d is in %s state, TTD in %d", fd, (isRead ? "read" : "write"), (ttd.tv_sec-lnow.tv_sec)); + } + else { + errlog("- Descriptor %d is in %s state, no TTD set", fd, (isRead ? "read" : "write")); + } + + if (param.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(param); + errlog(" - %s", conn->toString()); + } + else if (param.type() == typeid(DoHClientThreadData*)) { + errlog(" - Worker thread pipe"); + } + }); + } + } + } + } +} + +void DoHClientCollection::addThread() +{ + auto preparePipe = [](int fds[2], const std::string& type) -> bool { + if (pipe(fds) < 0) { + errlog("Error creating the DoH thread %s pipe: %s", type, stringerror()); + return false; + } + + if (!setNonBlocking(fds[0])) { + int err = errno; + close(fds[0]); + close(fds[1]); + errlog("Error setting the DoH thread %s pipe non-blocking: %s", type, stringerror(err)); + return false; + } + + if (!setNonBlocking(fds[1])) { + int err = errno; + close(fds[0]); + close(fds[1]); + errlog("Error setting the DoH thread %s pipe non-blocking: %s", type, stringerror(err)); + return false; + } + + if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) { + setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize); + } + + return true; + }; + + int crossProtocolFDs[2] = { -1, -1}; + if (!preparePipe(crossProtocolFDs, "cross-protocol")) { + return; + } + + vinfolog("Adding DoH Client thread"); + + { + std::lock_guard lock(d_mutex); + + if (d_numberOfThreads >= d_clientThreads.size()) { + vinfolog("Adding a new DoH client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of DoH client threads with setMaxDoHClientThreads() in the configuration.", d_numberOfThreads.load(), d_clientThreads.size()); + close(crossProtocolFDs[0]); + close(crossProtocolFDs[1]); + return; + } + + /* from now on this side of the pipe will be managed by that object, + no need to worry about it */ + DoHWorkerThread worker(crossProtocolFDs[1]); + try { + std::thread t1(dohClientThread, crossProtocolFDs[0]); + t1.detach(); + } + catch (const std::runtime_error& e) { + /* the thread creation failed, don't leak */ + errlog("Error creating a DoH thread: %s", e.what()); + close(crossProtocolFDs[0]); + return; + } + + d_clientThreads.at(d_numberOfThreads) = std::move(worker); + ++d_numberOfThreads; + } +} + +bool initDoHWorkers() +{ +#warning FIXME: number of DoH threads + g_dohClientThreads = std::make_unique(1); + g_dohClientThreads->addThread(); + return true; +} diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.hh b/pdns/dnsdistdist/dnsdist-nghttp2.hh index 49ce323616..9713735823 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.hh +++ b/pdns/dnsdistdist/dnsdist-nghttp2.hh @@ -1,3 +1,63 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ #pragma once -void sendHTTP2Query(); +#include +#include +#include + +#include "stat_t.hh" + +struct CrossProtocolQuery; + +class DoHClientCollection +{ +public: + DoHClientCollection(size_t maxThreads); + + bool hasReachedMaxThreads() const + { + return d_numberOfThreads >= d_maxThreads; + } + + uint64_t getThreadsCount() const + { + return d_numberOfThreads; + } + + bool passCrossProtocolQueryToThread(std::unique_ptr&& cpq); + void addThread(); + +private: + struct DoHWorkerThread; + + std::mutex d_mutex; + std::vector d_clientThreads; + pdns::stat_t d_numberOfThreads{0}; + pdns::stat_t d_pos{0}; + const uint64_t d_maxThreads{0}; +}; + +extern std::unique_ptr g_dohClientThreads; +extern std::atomic g_dohStatesDumpRequested; + +bool initDoHWorkers(); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index d480b38935..2201f2df4c 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -225,8 +225,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c conn->d_pendingResponses.clear(); conn->d_currentPos = 0; - if (conn->d_state == State::doingHandshake || - conn->d_state == State::sendingQueryToBackend) { + if (conn->d_state == State::sendingQueryToBackend) { iostate = IOState::NeedWrite; // resume sending query } @@ -295,7 +294,7 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t } struct timeval now; - gettimeofday(&now, 0); + gettimeofday(&now, nullptr); handleIO(conn, now); } @@ -309,7 +308,7 @@ void TCPConnectionToBackend::queueQuery(std::shared_ptr& sender, throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries"); } - // if we are not already sending a query or in the middle of reading a response (so idle or doingHandshake), + // if we are not already sending a query or in the middle of reading a response (so idle), // start sending the query if (d_state == State::idle || d_state == State::waitingForResponseFromBackend) { DEBUGLOG("Sending new query to backend right away"); @@ -615,7 +614,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrgetNameWithAddr()); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh index 3745de9657..9301ad70d8 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh @@ -15,7 +15,7 @@ public: reconnect(); } - ~TCPConnectionToBackend(); + virtual ~TCPConnectionToBackend(); int getHandle() const { @@ -112,8 +112,8 @@ public: return ds == d_ds; } - void queueQuery(std::shared_ptr& sender, TCPQuery&& query); - void handleTimeout(const struct timeval& now, bool write); + virtual void queueQuery(std::shared_ptr& sender, TCPQuery&& query); + virtual void handleTimeout(const struct timeval& now, bool write); void release(); void setProxyProtocolValuesSent(std::unique_ptr>&& proxyProtocolValuesSent); @@ -123,17 +123,17 @@ public: return d_lastDataReceivedTime; } - std::string toString() const + virtual std::string toString() const { ostringstream o; o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<& conn, const struct timeval& now); @@ -143,7 +143,7 @@ private: static bool isXFRFinished(const TCPResponse& response, TCPQuery& query); IOState handleResponse(std::shared_ptr& conn, const struct timeval& now); - uint16_t getQueryIdFromResponse(); + uint16_t getQueryIdFromResponse() const; bool reconnect(); void notifyAllQueriesFailed(const struct timeval& now, FailureReason reason); bool needProxyProtocolPayload() const @@ -197,6 +197,7 @@ private: } PacketBuffer d_responseBuffer; +#warning we do not need this and could append to the outgoing buffer right away but is it better? std::deque d_pendingQueries; std::unordered_map d_pendingResponses; std::unique_ptr& d_mplexer; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index 99da15c094..e48c599f64 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -21,6 +21,10 @@ */ #pragma once +#include +#include "iputils.hh" +#include "dnsdist.hh" + struct ConnectionInfo { ConnectionInfo(ClientState* cs_) :