From: Remi Gacogne Date: Fri, 30 Jun 2023 15:49:35 +0000 (+0200) Subject: dnsdist: Add unit and regression tests for incoming DoH w/ nghttp2 X-Git-Tag: rec-5.0.0-alpha1~19^2~21 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=c02b7e139240e28260dea2313e8af933e5dda332;p=thirdparty%2Fpdns.git dnsdist: Add unit and regression tests for incoming DoH w/ nghttp2 It is quite likely that the underlying TLS layer has buffered some data already, so we need to consume it before trying to poll the socket. --- diff --git a/pdns/dnsdist-doh-common.hh b/pdns/dnsdist-doh-common.hh index 41166de9f3..f0a1adc767 100644 --- a/pdns/dnsdist-doh-common.hh +++ b/pdns/dnsdist-doh-common.hh @@ -77,6 +77,10 @@ struct DOHFrontend DOHFrontend() { } + DOHFrontend(std::shared_ptr tlsCtx): + d_tlsContext(std::move(tlsCtx)) + { + } virtual ~DOHFrontend() { diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index e4f30eaa83..0c07520108 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -332,7 +332,9 @@ testrunner_SOURCES = \ test-dnsdistkvs_cc.cc \ test-dnsdistlbpolicies_cc.cc \ test-dnsdistluanetwork.cc \ + test-dnsdistnghttp2-in_cc.cc \ test-dnsdistnghttp2_cc.cc \ + test-dnsdistnghttp2_common.hh \ test-dnsdistpacketcache_cc.cc \ test-dnsdistrings_cc.cc \ test-dnsdistrules_cc.cc \ diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc index 21098ec96e..b71601a916 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc @@ -142,7 +142,7 @@ public: d_query.d_contentTypeOut = contentType; } - void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr& downstream) override + void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr& downstream_) override { std::unique_ptr unit(this); auto conn = d_connection.lock(); @@ -153,7 +153,7 @@ public: state.du = std::move(unit); TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr); - resp.d_ds = downstream; + resp.d_ds = downstream_; struct timeval now { }; @@ -263,7 +263,7 @@ IncomingHTTP2Connection::IncomingHTTP2Connection(ConnectionInfo&& connectionInfo bool IncomingHTTP2Connection::checkALPN() { constexpr std::array h2ALPN{'h', '2'}; - auto protocols = d_handler.getNextProtocol(); + const auto protocols = d_handler.getNextProtocol(); if (protocols.size() == h2ALPN.size() && memcmp(protocols.data(), h2ALPN.data(), h2ALPN.size()) == 0) { return true; } @@ -285,6 +285,11 @@ void IncomingHTTP2Connection::handleConnectionReady() } } +bool IncomingHTTP2Connection::hasPendingWrite() const +{ + return d_pendingWrite; +} + void IncomingHTTP2Connection::handleIO() { IOState iostate = IOState::Done; @@ -297,7 +302,7 @@ void IncomingHTTP2Connection::handleIO() if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort()); stopIO(); - d_connectionDied = true; + d_connectionClosing = true; return; } @@ -341,56 +346,94 @@ void IncomingHTTP2Connection::handleIO() } } - if (d_state == State::waitingForQuery || d_state == State::idle) { - readHTTPData(); + if (active() && !d_connectionClosing && (d_state == State::waitingForQuery || d_state == State::idle)) { + do { + iostate = readHTTPData(); + } while (active() && !d_connectionClosing && iostate == IOState::Done); } - if (!d_connectionDied) { - auto shared = std::dynamic_pointer_cast(shared_from_this()); + if (!active()) { + stopIO(); + return; + } + /* + So: + - if we have a pending write, we need to wait until the socket becomes writable + and then call handleWritableCallback + - if we have NeedWrite but no pending write, we need to wait until the socket + becomes writable but for handleReadableIOCallback + - if we have NeedRead, or nghttp2_session_want_read, wait until the socket + becomes readable and call handleReadableIOCallback + */ + if (hasPendingWrite()) { + updateIO(IOState::NeedWrite, handleWritableIOCallback); + } + else if (iostate == IOState::NeedWrite) { + updateIO(IOState::NeedWrite, handleReadableIOCallback); + } + else if (!d_connectionClosing) { if (nghttp2_session_want_read(d_session.get()) != 0) { - d_ioState->add(IOState::NeedRead, &handleReadableIOCallback, shared, boost::none); + updateIO(IOState::NeedRead, handleReadableIOCallback); } - if (nghttp2_session_want_write(d_session.get()) != 0) { - d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, shared, boost::none); + else { + if (isIdle()) { + watchForRemoteHostClosingConnection(); + } } } } catch (const std::exception& e) { - vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what()); + infolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what()); d_connectionDied = true; stopIO(); } } -ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) +void IncomingHTTP2Connection::writeToSocket(bool socketReady) { - auto* conn = static_cast(user_data); - // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API - conn->d_out.insert(conn->d_out.end(), data, data + length); - - if (conn->d_connectionDied || conn->d_needFlush) { - try { - conn->d_needFlush = false; - auto state = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); - if (state == IOState::Done) { - conn->d_out.clear(); - conn->d_outPos = 0; - if (!conn->isIdle()) { - conn->updateIO(IOState::NeedRead, handleReadableIOCallback); + try { + d_needFlush = false; + IOState newState = d_handler.tryWrite(d_out, d_outPos, d_out.size()); + + if (newState == IOState::Done) { + d_pendingWrite = false; + d_out.clear(); + d_outPos = 0; + if (active() && !d_connectionClosing) { + if (!isIdle()) { + updateIO(IOState::NeedRead, handleReadableIOCallback); } else { - conn->watchForRemoteHostClosingConnection(); + watchForRemoteHostClosingConnection(); } } else { - conn->updateIO(state, handleWritableIOCallback); + stopIO(); } } - catch (const std::exception& e) { - vinfolog("Exception while trying to write (send) to incoming HTTP connection to %s: %s", conn->d_ci.remote.toStringWithPort(), e.what()); - conn->handleIOError(); + else { + updateIO(newState, handleWritableIOCallback); + d_pendingWrite = true; } } + catch (const std::exception& e) { + vinfolog("Exception while trying to write (%s) to HTTP client connection to %s: %s", (socketReady ? "ready" : "send"), d_ci.remote.toStringWithPort(), e.what()); + handleIOError(); + } +} + +ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) +{ + auto* conn = static_cast(user_data); + if (conn->d_connectionDied) { + return static_cast(length); + } + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API + conn->d_out.insert(conn->d_out.end(), data, data + length); + + if (conn->d_connectionClosing || conn->d_needFlush) { + conn->writeToSocket(false); + } return static_cast(length); } @@ -471,7 +514,7 @@ IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResp sendResponse(response.d_idstate.d_streamID, context, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType); handleResponseSent(response); - return IOState::Done; + return hasPendingWrite() ? IOState::NeedWrite : IOState::Done; } void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response) @@ -748,6 +791,12 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi sendResponse(streamID, query, code, d_ci.cs->dohFrontend->d_customResponseHeaders); }; + if (query.d_method == PendingQuery::Method::Unknown || + query.d_method == PendingQuery::Method::Unsupported) { + handleImmediateResponse(400, "DoH query not allowed because of unsupported HTTP method"); + return; + } + ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries; if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) { @@ -864,44 +913,8 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) { auto* conn = static_cast(user_data); -#if 0 - switch (frame->hd.type) { - case NGHTTP2_HEADERS: - cerr<<"got headers"<headers.cat == NGHTTP2_HCAT_RESPONSE) { - cerr<<"All headers received"<headers.cat == NGHTTP2_HCAT_REQUEST) { - cerr<<"All headers received - query"<settings.niv<settings.niv; idx++) { - cerr<<"- "<settings.iv[idx].settings_id<<" "<settings.iv[idx].value<hd.type == NGHTTP2_GOAWAY) { - conn->stopIO(); - if (conn->isIdle()) { - if (nghttp2_session_want_write(conn->d_session.get()) != 0) { - conn->d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, conn, boost::none); - } - } - } - /* is this the last frame for this stream? */ - else if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) { + if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) { auto streamID = frame->hd.stream_id; auto stream = conn->d_currentStreams.find(streamID); if (stream != conn->d_currentStreams.end()) { @@ -959,7 +972,8 @@ int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session, if (!insertPair.second) { /* there is a stream ID collision, something is very wrong! */ vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort()); - conn->d_connectionDied = true; + conn->d_connectionClosing = true; + conn->d_needFlush = true; nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR); auto ret = nghttp2_session_send(conn->d_session.get()); if (ret != 0) { @@ -1047,8 +1061,9 @@ int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const query.d_method = PendingQuery::Method::Post; } else { + query.d_method = PendingQuery::Method::Unsupported; vinfolog("Unsupported method value"); - return NGHTTP2_ERR_CALLBACK_FAILURE; + return 0; } } @@ -1087,7 +1102,8 @@ int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib auto* conn = static_cast(user_data); vinfolog("Error in HTTP/2 connection from %d: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len)); - conn->d_connectionDied = true; + conn->d_connectionClosing = true; + conn->d_needFlush = true; nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR); auto ret = nghttp2_session_send(conn->d_session.get()); if (ret != 0) { @@ -1098,55 +1114,35 @@ int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib return 0; } -void IncomingHTTP2Connection::readHTTPData() +IOState IncomingHTTP2Connection::readHTTPData() { IOState newState = IOState::Done; - IOStateGuard ioGuard(d_ioState); - do { - size_t got = 0; - if (d_in.size() < 128) { - d_in.resize(std::max(static_cast(128U), d_in.capacity())); - } - try { - newState = d_handler.tryRead(d_in, got, d_in.size(), true); - d_in.resize(got); - - if (got > 0) { - /* we got something */ - auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size()); - /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB, - all data should be consumed before returning */ - if (readlen < 0 || static_cast(readlen) < d_in.size()) { - throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen))); - } - - nghttp2_session_send(d_session.get()); + size_t got = 0; + if (d_in.size() < s_initialReceiveBufferSize) { + d_in.resize(std::max(s_initialReceiveBufferSize, d_in.capacity())); + } + try { + newState = d_handler.tryRead(d_in, got, d_in.size(), true); + d_in.resize(got); + + if (got > 0) { + /* we got something */ + auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size()); + /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB, + all data should be consumed before returning */ + if (readlen < 0 || static_cast(readlen) < d_in.size()) { + throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen))); } - if (newState == IOState::Done) { - if (nghttp2_session_want_read(d_session.get()) != 0) { - continue; - } - if (isIdle()) { - watchForRemoteHostClosingConnection(); - ioGuard.release(); - break; - } - } - else { - if (newState == IOState::NeedWrite) { - updateIO(IOState::NeedWrite, handleReadableIOCallback); - } - ioGuard.release(); - break; - } + nghttp2_session_send(d_session.get()); } - catch (const std::exception& e) { - vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what()); - handleIOError(); - break; - } - } while (newState == IOState::Done || !isIdle()); + } + catch (const std::exception& e) { + vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what()); + handleIOError(); + return IOState::Done; + } + return newState; } void IncomingHTTP2Connection::handleReadableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param) @@ -1158,29 +1154,7 @@ void IncomingHTTP2Connection::handleReadableIOCallback([[maybe_unused]] int desc void IncomingHTTP2Connection::handleWritableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param) { auto conn = boost::any_cast>(param); - IOStateGuard ioGuard(conn->d_ioState); - - try { - IOState newState = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); - if (newState == IOState::NeedRead) { - conn->updateIO(IOState::NeedRead, handleWritableIOCallback); - } - else if (newState == IOState::Done) { - conn->d_out.clear(); - conn->d_outPos = 0; - if (!conn->isIdle()) { - conn->updateIO(IOState::NeedRead, handleReadableIOCallback); - } - else { - conn->watchForRemoteHostClosingConnection(); - } - } - ioGuard.release(); - } - catch (const std::exception& e) { - vinfolog("Exception while trying to write (ready) to HTTP client connection to %s: %s", conn->d_ci.remote.toStringWithPort(), e.what()); - conn->handleIOError(); - } + conn->writeToSocket(true); } bool IncomingHTTP2Connection::isIdle() const @@ -1250,14 +1224,31 @@ void IncomingHTTP2Connection::updateIO(IOState newState, const FDMultiplexer::ca void IncomingHTTP2Connection::watchForRemoteHostClosingConnection() { - updateIO(IOState::NeedRead, handleReadableIOCallback); + if (d_connectionDied) { + return; + } + + if (hasPendingWrite()) { + updateIO(IOState::NeedWrite, &handleWritableIOCallback); + } + else if (!d_connectionClosing) { + updateIO(IOState::NeedRead, handleReadableIOCallback); + } } void IncomingHTTP2Connection::handleIOError() { d_connectionDied = true; + d_out.clear(); + d_outPos = 0; nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR); d_currentStreams.clear(); stopIO(); } + +bool IncomingHTTP2Connection::active() const +{ + return !d_connectionDied && d_ioState != nullptr; +} + #endif /* HAVE_NGHTTP2 */ diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.hh b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh index e68d214208..3db7473a8e 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.hh +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh @@ -39,7 +39,8 @@ public: { Unknown, Get, - Post + Post, + Unsupported }; PacketBuffer d_buffer; @@ -61,6 +62,7 @@ public: void handleIO() override; void handleResponse(const struct timeval& now, TCPResponse&& response) override; void notifyIOError(const struct timeval& now, TCPResponse&& response) override; + bool active() const override; private: static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data); @@ -73,6 +75,8 @@ private: static void handleReadableIOCallback(int descriptor, FDMultiplexer::funcparam_t& param); static void handleWritableIOCallback(int descriptor, FDMultiplexer::funcparam_t& param); + static constexpr size_t s_initialReceiveBufferSize{256U}; + IOState sendResponse(const struct timeval& now, TCPResponse&& response) override; bool forwardViaUDPFirst() const override { @@ -90,8 +94,10 @@ private: bool sendResponse(StreamID streamID, PendingQuery& context, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true); void handleIncomingQuery(PendingQuery&& query, StreamID streamID); bool checkALPN(); - void readHTTPData(); + IOState readHTTPData(); void handleConnectionReady(); + bool hasPendingWrite() const; + void writeToSocket(bool socketReady); boost::optional getIdleClientReadTTD(struct timeval now) const; std::unique_ptr d_session{nullptr, nghttp2_session_del}; @@ -99,8 +105,18 @@ private: PacketBuffer d_out; PacketBuffer d_in; size_t d_outPos{0}; + /* this connection is done, the remote end has closed the connection + or something like that. We do not want to try to write to it. */ bool d_connectionDied{false}; + /* we are done reading from this connection, but we might still want to + write to it to close it properly */ + bool d_connectionClosing{false}; + /* Whether we are still waiting for more data to be buffered + before writing to the socket (false) or not. */ bool d_needFlush{false}; + /* Whether we have data that we want to write to the socket, + but the socket is full. */ + bool d_pendingWrite{false}; }; class NGHTTP2Headers diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 03fa0bfca5..9b774f0ed4 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -300,10 +300,9 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, */ nghttp2_data_provider data_provider; - /* we will not use this pointer */ data_provider.source.ptr = this; data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t { - auto conn = reinterpret_cast(user_data); + auto conn = static_cast(user_data); auto& request = conn->d_currentStreams.at(stream_id); size_t toCopy = 0; if (request.d_queryPos < request.d_query.d_buffer.size()) { diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2-in_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2-in_cc.cc new file mode 100644 index 0000000000..0ac62b3c3b --- /dev/null +++ b/pdns/dnsdistdist/test-dnsdistnghttp2-in_cc.cc @@ -0,0 +1,727 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#define BOOST_TEST_DYN_LINK +#define BOOST_TEST_NO_MAIN + +#include + +#include "dnswriter.hh" +#include "dnsdist.hh" +#include "dnsdist-proxy-protocol.hh" +#include "dnsdist-nghttp2-in.hh" + +#ifdef HAVE_NGHTTP2 +#include + +extern std::function& selectedBackend)> s_processQuery; + +BOOST_AUTO_TEST_SUITE(test_dnsdistnghttp2_in_cc) + +struct ExpectedStep +{ +public: + enum class ExpectedRequest + { + handshakeClient, + readFromClient, + writeToClient, + closeClient, + }; + + ExpectedStep(ExpectedRequest r, IOState n, size_t b = 0, std::function fn = nullptr) : + cb(fn), request(r), nextState(n), bytes(b) + { + } + + std::function cb{nullptr}; + ExpectedRequest request; + IOState nextState; + size_t bytes{0}; +}; + +struct ExpectedData +{ + PacketBuffer d_proxyProtocolPayload; + std::vector d_queries; + std::vector d_responses; + std::vector d_responseCodes; +}; + +class DOHConnection; + +static std::deque s_steps; +static std::map s_connectionContexts; +static std::map> s_connectionBuffers; +static uint64_t s_connectionID{0}; + +std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d); + +std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d) +{ + static const std::vector requests = {"handshake with client", "read from client", "write to client", "close connection to client", "connect to the backend", "read from the backend", "write to the backend", "close connection to backend"}; + os << requests.at(static_cast(d)); + return os; +} + +class DOHConnection +{ +public: + DOHConnection(uint64_t connectionID) : + d_session(std::unique_ptr(nullptr, nghttp2_session_del)), d_connectionID(connectionID) + { + const auto& context = s_connectionContexts.at(connectionID); + d_clientOutBuffer.insert(d_clientOutBuffer.begin(), context.d_proxyProtocolPayload.begin(), context.d_proxyProtocolPayload.end()); + + nghttp2_session_callbacks* cbs = nullptr; + nghttp2_session_callbacks_new(&cbs); + std::unique_ptr callbacks(cbs, nghttp2_session_callbacks_del); + cbs = nullptr; + nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback); + nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback); + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback); + nghttp2_session* sess = nullptr; + nghttp2_session_client_new(&sess, callbacks.get(), this); + d_session = std::unique_ptr(sess, nghttp2_session_del); + + nghttp2_settings_entry iv[] = { + /* rfc7540 section-8.2.2: + "Advertising a SETTINGS_MAX_CONCURRENT_STREAMS value of zero disables + server push by preventing the server from creating the necessary + streams." + */ + {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 0}, + {NGHTTP2_SETTINGS_ENABLE_PUSH, 0}, + /* we might want to make the initial window size configurable, but 16M is a large enough default */ + {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16 * 1024 * 1024}}; + /* client 24 bytes magic string will be sent by nghttp2 library */ + auto result = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv) / sizeof(*iv)); + if (result != 0) { + throw std::runtime_error("Error submitting settings:" + std::string(nghttp2_strerror(result))); + } + + const std::string host("unit-tests"); + const std::string path("/dns-query"); + for (const auto& query : context.d_queries) { + const auto querySize = std::to_string(query.size()); + std::vector headers; + /* Pseudo-headers need to come first (rfc7540 8.1.2.1) */ + NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::METHOD_NAME, NGHTTP2Headers::HeaderConstantIndexes::METHOD_VALUE); + NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::SCHEME_NAME, NGHTTP2Headers::HeaderConstantIndexes::SCHEME_VALUE); + NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::AUTHORITY_NAME, host); + NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::PATH_NAME, path); + NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::ACCEPT_NAME, NGHTTP2Headers::HeaderConstantIndexes::ACCEPT_VALUE); + NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_VALUE); + NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::USER_AGENT_NAME, NGHTTP2Headers::HeaderConstantIndexes::USER_AGENT_VALUE); + NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_LENGTH_NAME, querySize); + + d_position = 0; + d_currentQuery = query; + nghttp2_data_provider data_provider; + data_provider.source.ptr = this; + data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t { + auto* conn = static_cast(user_data); + auto& pos = conn->d_position; + const auto& currentQuery = conn->d_currentQuery; + size_t toCopy = 0; + if (pos < currentQuery.size()) { + size_t remaining = currentQuery.size() - pos; + toCopy = length > remaining ? remaining : length; + memcpy(buf, ¤tQuery.at(pos), toCopy); + pos += toCopy; + } + + if (pos >= currentQuery.size()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return toCopy; + }; + + auto newStreamId = nghttp2_submit_request(d_session.get(), nullptr, headers.data(), headers.size(), &data_provider, this); + if (newStreamId < 0) { + throw std::runtime_error("Error submitting HTTP request:" + std::string(nghttp2_strerror(newStreamId))); + } + + result = nghttp2_session_send(d_session.get()); + if (result != 0) { + throw std::runtime_error("Error in nghttp2_session_send:" + std::to_string(result)); + } + } + } + + std::map d_responses; + std::map d_responseCodes; + std::unique_ptr d_session; + PacketBuffer d_currentQuery; + PacketBuffer d_clientOutBuffer; + uint64_t d_connectionID{0}; + size_t d_position{0}; + + size_t submitIncoming(const PacketBuffer& data, size_t pos, size_t toWrite) + { + ssize_t readlen = nghttp2_session_mem_recv(d_session.get(), &data.at(pos), toWrite); + if (readlen < 0) { + throw("Fatal error while submitting line " + std::to_string(__LINE__) + ": " + std::string(nghttp2_strerror(static_cast(readlen)))); + } + + /* just in case, see if we have anything to send */ + int rv = nghttp2_session_send(d_session.get()); + if (rv != 0) { + throw("Fatal error while sending: " + std::string(nghttp2_strerror(rv))); + } + + return readlen; + } + +private: + static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) + { + DOHConnection* conn = static_cast(user_data); + conn->d_clientOutBuffer.insert(conn->d_clientOutBuffer.end(), data, data + length); + return static_cast(length); + } + + static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) + { + DOHConnection* conn = static_cast(user_data); + if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + const auto& response = conn->d_responses.at(frame->hd.stream_id); + if (conn->d_responseCodes.at(frame->hd.stream_id) != 200U) { + return 0; + } + + BOOST_REQUIRE_GT(response.size(), sizeof(dnsheader)); + const auto* dh = reinterpret_cast(response.data()); + uint16_t id = ntohs(dh->id); + + const auto& expected = s_connectionContexts.at(conn->d_connectionID).d_responses.at(id); + BOOST_REQUIRE_EQUAL(expected.size(), response.size()); + for (size_t idx = 0; idx < response.size(); idx++) { + if (expected.at(idx) != response.at(idx)) { + cerr << "Mismatch at offset " << idx << ", expected " << std::to_string(response.at(idx)) << " got " << std::to_string(expected.at(idx)) << endl; + BOOST_CHECK(false); + } + } + } + + return 0; + } + + static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data) + { + DOHConnection* conn = static_cast(user_data); + auto& response = conn->d_responses[stream_id]; + response.insert(response.end(), data, data + len); + return 0; + } + + static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data) + { + DOHConnection* conn = static_cast(user_data); + + const std::string status(":status"); + if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { + if (namelen == status.size() && memcmp(status.data(), name, status.size()) == 0) { + try { + uint16_t responseCode{0}; + auto expected = s_connectionContexts.at(conn->d_connectionID).d_responseCodes.at((frame->hd.stream_id - 1) / 2); + pdns::checked_stoi_into(responseCode, std::string(reinterpret_cast(value), valuelen)); + conn->d_responseCodes[frame->hd.stream_id] = responseCode; + if (responseCode != expected) { + cerr << "Mismatch response code, expected " << std::to_string(expected) << " got " << std::to_string(responseCode) << endl; + BOOST_CHECK(false); + } + } + catch (const std::exception& e) { + infolog("Error parsing the status header for stream ID %d: %s", frame->hd.stream_id, e.what()); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + } + return 0; + } + + static int on_stream_close_callback(nghttp2_session* session, int32_t stream_id, uint32_t error_code, void* user_data) + { + return 0; + } +}; + +class MockupTLSConnection : public TLSConnection +{ +public: + MockupTLSConnection(int descriptor, [[maybe_unused]] bool client = false, [[maybe_unused]] bool needProxyProtocol = false) : + d_descriptor(descriptor) + { + auto connectionID = s_connectionID++; + auto conn = std::make_unique(connectionID); + s_connectionBuffers[d_descriptor] = std::move(conn); + } + + ~MockupTLSConnection() {} + + IOState tryHandshake() override + { + auto step = getStep(); + BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::handshakeClient); + + return step.nextState; + } + + IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override + { + auto& conn = s_connectionBuffers.at(d_descriptor); + auto step = getStep(); + BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::writeToClient); + + if (step.bytes == 0) { + if (step.nextState == IOState::NeedWrite) { + return step.nextState; + } + throw std::runtime_error("Remote host closed the connection"); + } + + toWrite -= pos; + BOOST_REQUIRE_GE(buffer.size(), pos + toWrite); + + if (step.bytes < toWrite) { + toWrite = step.bytes; + } + + conn->submitIncoming(buffer, pos, toWrite); + pos += toWrite; + + return step.nextState; + } + + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete = false) override + { + auto& conn = s_connectionBuffers.at(d_descriptor); + auto step = getStep(); + BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::readFromClient); + + if (step.bytes == 0) { + if (step.nextState == IOState::NeedRead) { + return step.nextState; + } + throw std::runtime_error("Remote host closed the connection"); + } + + auto& externalBuffer = conn->d_clientOutBuffer; + toRead -= pos; + + if (step.bytes < toRead) { + toRead = step.bytes; + } + if (allowIncomplete) { + if (toRead > externalBuffer.size()) { + toRead = externalBuffer.size(); + } + } + else { + BOOST_REQUIRE_GE(externalBuffer.size(), toRead); + } + + BOOST_REQUIRE_GE(buffer.size(), toRead); + + std::copy(externalBuffer.begin(), externalBuffer.begin() + toRead, buffer.begin() + pos); + pos += toRead; + externalBuffer.erase(externalBuffer.begin(), externalBuffer.begin() + toRead); + + return step.nextState; + } + + IOState tryConnect(bool fastOpen, const ComboAddress& remote) override + { + throw std::runtime_error("Should not happen"); + } + + void close() override + { + auto step = getStep(); + BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::closeClient); + } + + bool hasBufferedData() const override + { + return false; + } + + bool isUsable() const override + { + return true; + } + + std::string getServerNameIndication() const override + { + return ""; + } + + std::vector getNextProtocol() const override + { + return std::vector{'h', '2'}; + } + + LibsslTLSVersion getTLSVersion() const override + { + return LibsslTLSVersion::TLS13; + } + + bool hasSessionBeenResumed() const override + { + return false; + } + + std::vector> getSessions() override + { + return {}; + } + + void setSession(std::unique_ptr& session) override + { + } + + std::vector getAsyncFDs() override + { + return {}; + } + + /* unused in that context, don't bother */ + void doHandshake() override + { + } + + void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override + { + } + + size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0, 0}, bool allowIncomplete = false) override + { + return 0; + } + + size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override + { + return 0; + } + +private: + ExpectedStep getStep() const + { + BOOST_REQUIRE(!s_steps.empty()); + auto step = s_steps.front(); + s_steps.pop_front(); + + if (step.cb) { + step.cb(d_descriptor); + } + + return step; + } + + const int d_descriptor; +}; + +#include "test-dnsdistnghttp2_common.hh" + +struct TestFixture +{ + TestFixture() + { + s_steps.clear(); + s_connectionContexts.clear(); + s_connectionBuffers.clear(); + s_connectionID = 0; + s_mplexer = std::make_unique(); + } + ~TestFixture() + { + s_steps.clear(); + s_connectionContexts.clear(); + s_connectionBuffers.clear(); + s_connectionID = 0; + s_mplexer.reset(); + } +}; + +BOOST_FIXTURE_TEST_CASE(test_IncomingConnection_SelfAnswered, TestFixture) +{ + auto local = getBackendAddress("1", 80); + ClientState localCS(local, true, false, false, "", {}); + localCS.dohFrontend = std::make_shared(std::make_shared()); + localCS.dohFrontend->d_urls.insert("/dns-query"); + + TCPClientThreadData threadData; + threadData.mplexer = std::make_unique(); + + struct timeval now; + gettimeofday(&now, nullptr); + + size_t counter = 0; + DNSName name("powerdns.com."); + PacketBuffer query; + GenericDNSPacketWriter pwQ(query, name, QType::A, QClass::IN, 0); + pwQ.getHeader()->rd = 1; + pwQ.getHeader()->id = htons(counter); + + PacketBuffer response; + GenericDNSPacketWriter pwR(response, name, QType::A, QClass::IN, 0); + pwR.getHeader()->qr = 1; + pwR.getHeader()->rd = 1; + pwR.getHeader()->ra = 1; + pwR.getHeader()->id = htons(counter); + pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); + pwR.xfr32BitInt(0x01020304); + pwR.commit(); + + { + /* dnsdist drops the query right away after receiving it, client closes the connection */ + s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {403U}}; + s_steps = { + /* opening */ + { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + /* settings server -> client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 }, + /* settings + headers + data client -> server.. */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 }, + /* .. continued */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 }, + /* headers + data */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits::max() }, + /* wait for next query, but the client closes the connection */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, + /* server close */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, + }; + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + state->handleIO(); + } + + { + /* client closes the connection right in the middle of sending the query */ + s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, { 403U }}; + s_steps = { + /* opening */ + { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + /* settings server -> client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 }, + /* client sends one byte */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 1 }, + /* then closes the connection */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, + /* server close */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, + }; + + /* mark the incoming FD as always ready */ + dynamic_cast(threadData.mplexer.get())->setReady(-1); + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + state->handleIO(); + while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { + threadData.mplexer->run(&now); + } + } + + { + /* dnsdist sends a response right away, client closes the connection after getting the response */ + s_processQuery = [response](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + /* self answered */ + dq.getMutableData() = response; + return ProcessQueryResult::SendAnswer; + }; + + s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}}; + + s_steps = { + /* opening */ + { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + /* settings server -> client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 }, + /* settings + headers + data client -> server.. */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 }, + /* .. continued */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 }, + /* headers + data */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits::max() }, + /* wait for next query, but the client closes the connection */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, + /* server close */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, + }; + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + state->handleIO(); + } + + { + /* dnsdist sends a response right away, but the client closes the connection without even reading the response */ + s_processQuery = [response](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + /* self answered */ + dq.getMutableData() = response; + return ProcessQueryResult::SendAnswer; + }; + + s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}}; + + s_steps = { + /* opening */ + { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + /* settings server -> client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 }, + /* settings + headers + data client -> server.. */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 }, + /* .. continued */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 }, + /* we want to send the response but the client closes the connection */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 }, + /* server close */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, + }; + + /* mark the incoming FD as always ready */ + dynamic_cast(threadData.mplexer.get())->setReady(-1); + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + state->handleIO(); + while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { + threadData.mplexer->run(&now); + } + } + + { + /* dnsdist sends a response right away, client closes the connection while getting the response */ + s_processQuery = [response](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + /* self answered */ + dq.getMutableData() = response; + return ProcessQueryResult::SendAnswer; + }; + + s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}}; + + s_steps = { + /* opening */ + { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + /* settings server -> client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 }, + /* settings + headers + data client -> server.. */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 }, + /* .. continued */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 }, + /* headers + data (partial write) */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::NeedWrite, 1 }, + /* nothing to read after that */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 }, + /* then the client closes the connection before we are done */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 }, + /* server close */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, + }; + + /* mark the incoming FD as always ready */ + dynamic_cast(threadData.mplexer.get())->setReady(-1); + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + state->handleIO(); + while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { + threadData.mplexer->run(&now); + } + } +} + +BOOST_FIXTURE_TEST_CASE(test_IncomingConnection_BackendTimeout, TestFixture) +{ + auto local = getBackendAddress("1", 80); + ClientState localCS(local, true, false, false, "", {}); + localCS.dohFrontend = std::make_shared(std::make_shared()); + localCS.dohFrontend->d_urls.insert("/dns-query"); + + TCPClientThreadData threadData; + threadData.mplexer = std::make_unique(); + + auto backend = std::make_shared(getBackendAddress("42", 53)); + + struct timeval now; + gettimeofday(&now, nullptr); + + size_t counter = 0; + DNSName name("powerdns.com."); + PacketBuffer query; + GenericDNSPacketWriter pwQ(query, name, QType::A, QClass::IN, 0); + pwQ.getHeader()->rd = 1; + pwQ.getHeader()->id = htons(counter); + + PacketBuffer response; + GenericDNSPacketWriter pwR(response, name, QType::A, QClass::IN, 0); + pwR.getHeader()->qr = 1; + pwR.getHeader()->rd = 1; + pwR.getHeader()->ra = 1; + pwR.getHeader()->id = htons(counter); + pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); + pwR.xfr32BitInt(0x01020304); + pwR.commit(); + + { + /* dnsdist forwards the query to the backend, which does not answer -> timeout */ + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + selectedBackend = backend; + return ProcessQueryResult::PassToBackend; + }; + s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {502U}}; + s_steps = { + /* opening */ + { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + /* settings server -> client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 }, + /* settings + headers + data client -> server.. */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 }, + /* .. continued */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 }, + /* trying to read a new request while processing the first one */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead }, + /* headers + data */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits::max(), [&threadData](int desc) { + /* set the incoming descriptor as ready */ + dynamic_cast(threadData.mplexer.get())->setReady(desc); + } + }, + /* wait for next query, but the client closes the connection */ + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, + /* server close */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, + }; + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + state->handleIO(); + TCPResponse resp; + resp.d_idstate.d_streamID = 1; + state->notifyIOError(now, std::move(resp)); + while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { + threadData.mplexer->run(&now); + } + } +} + +BOOST_AUTO_TEST_SUITE_END(); +#endif /* HAVE_NGHTTP2 */ diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc index d10e85ef13..3e5bb16312 100644 --- a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc @@ -486,110 +486,7 @@ private: bool d_client{false}; }; -class MockupTLSCtx : public TLSCtx -{ -public: - ~MockupTLSCtx() - { - } - - std::unique_ptr getConnection(int socket, const struct timeval& timeout, time_t now) override - { - return std::make_unique(socket); - } - - std::unique_ptr getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override - { - return std::make_unique(socket, true, d_needProxyProtocol); - } - - void rotateTicketsKey(time_t now) override - { - } - - size_t getTicketsKeysCount() override - { - return 0; - } - - std::string getName() const override - { - return "Mockup TLS"; - } - - bool d_needProxyProtocol{false}; -}; - -class MockupFDMultiplexer : public FDMultiplexer -{ -public: - MockupFDMultiplexer() - { - } - - ~MockupFDMultiplexer() - { - } - - int run(struct timeval* tv, int timeout = 500) override - { - int ret = 0; - - gettimeofday(tv, nullptr); // MANDATORY - - /* 'ready' might be altered by a callback while we are iterating */ - const auto readyFDs = ready; - for (const auto fd : readyFDs) { - { - const auto& it = d_readCallbacks.find(fd); - - if (it != d_readCallbacks.end()) { - it->d_callback(it->d_fd, it->d_parameter); - } - } - - { - const auto& it = d_writeCallbacks.find(fd); - - if (it != d_writeCallbacks.end()) { - it->d_callback(it->d_fd, it->d_parameter); - } - } - } - - return ret; - } - - void getAvailableFDs(std::vector& fds, int timeout) override - { - } - - void addFD(int fd, FDMultiplexer::EventKind kind) override - { - } - - void removeFD(int fd, FDMultiplexer::EventKind) override - { - } - - string getName() const override - { - return "mockup"; - } - - void setReady(int fd) - { - ready.insert(fd); - } - - void setNotReady(int fd) - { - ready.erase(fd); - } - -private: - std::set ready; -}; +#include "test-dnsdistnghttp2_common.hh" class MockupQuerySender : public TCPQuerySender { @@ -641,36 +538,6 @@ public: bool d_error{false}; }; -static bool isIPv6Supported() -{ - try { - ComboAddress addr("[2001:db8:53::1]:53"); - auto socket = std::make_unique(addr.sin4.sin_family, SOCK_STREAM, 0); - socket->setNonBlocking(); - int res = SConnectWithTimeout(socket->getHandle(), addr, timeval{0, 0}); - if (res == 0 || res == EINPROGRESS) { - return true; - } - return false; - } - catch (const std::exception& e) { - return false; - } -} - -static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t port) -{ - static const bool useV6 = isIPv6Supported(); - - if (useV6) { - return ComboAddress("2001:db8:53::" + lastDigit, port); - } - - return ComboAddress("192.0.2." + lastDigit, port); -} - -static std::unique_ptr s_mplexer; - struct TestFixture { TestFixture() diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_common.hh b/pdns/dnsdistdist/test-dnsdistnghttp2_common.hh new file mode 100644 index 0000000000..5c79679aba --- /dev/null +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_common.hh @@ -0,0 +1,157 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +class MockupTLSCtx : public TLSCtx +{ +public: + ~MockupTLSCtx() + { + } + + std::unique_ptr getConnection(int socket, const struct timeval& timeout, time_t now) override + { + return std::make_unique(socket); + } + + std::unique_ptr getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override + { + return std::make_unique(socket, true, d_needProxyProtocol); + } + + void rotateTicketsKey(time_t now) override + { + } + + size_t getTicketsKeysCount() override + { + return 0; + } + + std::string getName() const override + { + return "Mockup TLS"; + } + + bool d_needProxyProtocol{false}; +}; + +class MockupFDMultiplexer : public FDMultiplexer +{ +public: + MockupFDMultiplexer() + { + } + + ~MockupFDMultiplexer() + { + } + + int run(struct timeval* tv, int timeout = 500) override + { + int ret = 0; + + gettimeofday(tv, nullptr); // MANDATORY + + /* 'ready' might be altered by a callback while we are iterating */ + const auto readyFDs = ready; + for (const auto fd : readyFDs) { + { + const auto& it = d_readCallbacks.find(fd); + + if (it != d_readCallbacks.end()) { + it->d_callback(it->d_fd, it->d_parameter); + } + } + + { + const auto& it = d_writeCallbacks.find(fd); + + if (it != d_writeCallbacks.end()) { + it->d_callback(it->d_fd, it->d_parameter); + } + } + } + + return ret; + } + + void getAvailableFDs(std::vector& fds, int timeout) override + { + } + + void addFD(int fd, FDMultiplexer::EventKind kind) override + { + } + + void removeFD(int fd, FDMultiplexer::EventKind) override + { + } + + string getName() const override + { + return "mockup"; + } + + void setReady(int fd) + { + ready.insert(fd); + } + + void setNotReady(int fd) + { + ready.erase(fd); + } + +private: + std::set ready; +}; + +static bool isIPv6Supported() +{ + try { + ComboAddress addr("[2001:db8:53::1]:53"); + auto socket = std::make_unique(addr.sin4.sin_family, SOCK_STREAM, 0); + socket->setNonBlocking(); + int res = SConnectWithTimeout(socket->getHandle(), addr, timeval{0, 0}); + if (res == 0 || res == EINPROGRESS) { + return true; + } + return false; + } + catch (const std::exception& e) { + return false; + } +} + +static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t port) +{ + static const bool useV6 = isIPv6Supported(); + + if (useV6) { + return ComboAddress("2001:db8:53::" + lastDigit, port); + } + + return ComboAddress("192.0.2." + lastDigit, port); +} + +static std::unique_ptr s_mplexer; diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 22e137c24b..dedfd97d2b 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -62,7 +62,7 @@ void handleResponseSent(const InternalQueryState& ids, double udiff, const Combo { } -static std::function& selectedBackend)> s_processQuery; +std::function& selectedBackend)> s_processQuery; ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend) { diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 850273eb8a..c51a930c04 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -56,7 +56,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query) { - return false; + return true; } namespace dnsdist { diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 1e7968aa3d..75068bea02 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -987,21 +987,32 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): return conn @classmethod - def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None): + def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, useProxyProtocol=False, conn=None): url = cls.getDOHGetURL(baseurl, query, rawQuery) - conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) - response_headers = BytesIO() - #conn.setopt(pycurl.VERBOSE, True) - conn.setopt(pycurl.URL, url) - conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) - # this means "really do HTTP/2, not HTTP/1 with Upgrade headers" - conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE) + + if not conn: + print('creating a new connection') + conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) + # this means "really do HTTP/2, not HTTP/1 with Upgrade headers" + conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE) + if useHTTPS: + print("disabling verify") conn.setopt(pycurl.SSL_VERIFYPEER, 1) conn.setopt(pycurl.SSL_VERIFYHOST, 2) if caFile: conn.setopt(pycurl.CAINFO, caFile) + if useProxyProtocol: + print('enabling PP') + # 274 is CURLOPT_HAPROXYPROTOCOL + conn.setopt(274, 1) + + response_headers = BytesIO() + #conn.setopt(pycurl.VERBOSE, True) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + conn.setopt(pycurl.HTTPHEADER, customHeaders) conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) @@ -1014,6 +1025,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): receivedQuery = None message = None cls._response_headers = '' + print('performing') data = conn.perform_rb() cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) if cls._rcode == 200 and not rawResponse: @@ -1076,8 +1088,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): cls._response_headers = response_headers.getvalue() return (receivedQuery, message) - def sendDOHQueryWrapper(self, query, response, useQueue=True): - return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) + def sendDOHQueryWrapper(self, query, response, useQueue=True, useProxyProtocol=False): + return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, useProxyProtocol=useProxyProtocol) def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True): return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) diff --git a/regression-tests.dnsdist/test_DOH.py b/regression-tests.dnsdist/test_DOH.py index ae6aac46a4..f9fce6be56 100644 --- a/regression-tests.dnsdist/test_DOH.py +++ b/regression-tests.dnsdist/test_DOH.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +import base64 import dns import os import time @@ -32,6 +33,7 @@ class DOHTests(object): addAction(HTTPPathRegexRule("^/PowerDNS-[0-9]"), SpoofAction("6.7.8.9")) addAction("http-status-action.doh.tests.powerdns.com.", HTTPStatusAction(200, "Plaintext answer", "text/plain")) addAction("http-status-action-redirect.doh.tests.powerdns.com.", HTTPStatusAction(307, "https://doh.powerdns.org")) + addAction("no-backend.doh.tests.powerdns.com.", PoolAction('this-pool-has-no-backend')) function dohHandler(dq) if dq:getHTTPScheme() == 'https' and dq:getHTTPHost() == '%s:%d' and dq:getHTTPPath() == '/' and dq:getHTTPQueryString() == '' then @@ -235,9 +237,133 @@ class DOHTests(object): (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, caFile=self._caCert, query=query, response=None, useQueue=False) self.assertEqual(receivedResponse, expectedResponse) + def testDOHWithoutQuery(self): + """ + DOH: Empty GET query + """ + name = 'empty-get.doh.tests.powerdns.com.' + url = self._dohBaseURL + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + conn.setopt(pycurl.CAINFO, self._caCert) + data = conn.perform_rb() + rcode = conn.getinfo(pycurl.RESPONSE_CODE) + self.assertEqual(rcode, 400) + + def testDOHShortPath(self): + """ + DOH: Short path in GET query + """ + name = 'short-path-get.doh.tests.powerdns.com.' + url = self._dohBaseURL + '/AA' + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + conn.setopt(pycurl.CAINFO, self._caCert) + data = conn.perform_rb() + rcode = conn.getinfo(pycurl.RESPONSE_CODE) + self.assertEqual(rcode, 404) + + def testDOHQueryNoParameter(self): + """ + DOH: No parameter GET query + """ + name = 'no-parameter-get.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + wire = query.to_wire() + b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') + url = self._dohBaseURL + '?not-dns=' + b64 + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + conn.setopt(pycurl.CAINFO, self._caCert) + data = conn.perform_rb() + rcode = conn.getinfo(pycurl.RESPONSE_CODE) + self.assertEqual(rcode, 400) + + def testDOHQueryInvalidBase64(self): + """ + DOH: Invalid Base64 GET query + """ + name = 'invalid-b64-get.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + wire = query.to_wire() + url = self._dohBaseURL + '?dns=' + '_-~~~~-_' + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + conn.setopt(pycurl.CAINFO, self._caCert) + data = conn.perform_rb() + rcode = conn.getinfo(pycurl.RESPONSE_CODE) + self.assertEqual(rcode, 400) + + def testDOHInvalidDNSHeaders(self): + """ + DOH: Invalid DNS headers + """ + name = 'invalid-dns-headers.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.flags |= dns.flags.QR + wire = query.to_wire() + b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') + url = self._dohBaseURL + '?dns=' + b64 + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + conn.setopt(pycurl.CAINFO, self._caCert) + data = conn.perform_rb() + rcode = conn.getinfo(pycurl.RESPONSE_CODE) + self.assertEqual(rcode, 400) + + def testDOHQueryInvalidMethod(self): + """ + DOH: Invalid method + """ + if self._dohLibrary == 'h2o': + raise unittest.SkipTest('h2o does not check the HTTP method') + name = 'invalid-method.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + wire = query.to_wire() + b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') + url = self._dohBaseURL + '?dns=' + b64 + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + conn.setopt(pycurl.CAINFO, self._caCert) + conn.setopt(pycurl.CUSTOMREQUEST, 'PATCH') + data = conn.perform_rb() + rcode = conn.getinfo(pycurl.RESPONSE_CODE) + self.assertEqual(rcode, 400) + + def testDOHQueryInvalidALPN(self): + """ + DOH: Invalid ALPN + """ + alpn = ['bogus-alpn'] + conn = self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=alpn) + try: + conn.send('AAAA') + response = conn.recv(65535) + self.assertFalse(response) + except: + pass + def testDOHInvalid(self): """ - DOH: Invalid query + DOH: Invalid DNS query """ name = 'invalid.doh.tests.powerdns.com.' invalidQuery = dns.message.make_query(name, 'A', 'IN', use_edns=False) @@ -268,13 +394,43 @@ class DOHTests(object): self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery) self.assertEqual(response, receivedResponse) - def testDOHWithoutQuery(self): + def testDOHInvalidHeaderName(self): """ - DOH: Empty GET query + DOH: Invalid HTTP header name query """ - name = 'empty-get.doh.tests.powerdns.com.' - url = self._dohBaseURL - conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0) + name = 'invalid-header-name.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + # this header is invalid, see rfc9113 section 8.2.1. Field Validity + customHeaders = ['{}: test'] + try: + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, customHeaders=customHeaders) + self.assertFalse(receivedQuery) + self.assertFalse(receivedResponse) + except pycurl.error: + pass + + def testDOHNoBackend(self): + """ + DOH: No backend + """ + if self._dohLibrary == 'h2o': + raise unittest.SkipTest('h2o does not check the HTTP method') + name = 'no-backend.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + wire = query.to_wire() + b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') + url = self._dohBaseURL + '?dns=' + b64 + conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2) conn.setopt(pycurl.URL, url) conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)]) conn.setopt(pycurl.SSL_VERIFYPEER, 1) @@ -282,7 +438,7 @@ class DOHTests(object): conn.setopt(pycurl.CAINFO, self._caCert) data = conn.perform_rb() rcode = conn.getinfo(pycurl.RESPONSE_CODE) - self.assertEqual(rcode, 400) + self.assertEqual(rcode, 403) def testDOHEmptyPOST(self): """