DOHFrontend()
{
}
+ DOHFrontend(std::shared_ptr<TLSCtx> tlsCtx):
+ d_tlsContext(std::move(tlsCtx))
+ {
+ }
virtual ~DOHFrontend()
{
test-dnsdistkvs_cc.cc \
test-dnsdistlbpolicies_cc.cc \
test-dnsdistluanetwork.cc \
+ test-dnsdistnghttp2-in_cc.cc \
test-dnsdistnghttp2_cc.cc \
+ test-dnsdistnghttp2_common.hh \
test-dnsdistpacketcache_cc.cc \
test-dnsdistrings_cc.cc \
test-dnsdistrules_cc.cc \
d_query.d_contentTypeOut = contentType;
}
- void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& downstream) override
+ void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& downstream_) override
{
std::unique_ptr<DOHUnitInterface> unit(this);
auto conn = d_connection.lock();
state.du = std::move(unit);
TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr);
- resp.d_ds = downstream;
+ resp.d_ds = downstream_;
struct timeval now
{
};
bool IncomingHTTP2Connection::checkALPN()
{
constexpr std::array<uint8_t, 2> h2ALPN{'h', '2'};
- auto protocols = d_handler.getNextProtocol();
+ const auto protocols = d_handler.getNextProtocol();
if (protocols.size() == h2ALPN.size() && memcmp(protocols.data(), h2ALPN.data(), h2ALPN.size()) == 0) {
return true;
}
}
}
+bool IncomingHTTP2Connection::hasPendingWrite() const
+{
+ return d_pendingWrite;
+}
+
void IncomingHTTP2Connection::handleIO()
{
IOState iostate = IOState::Done;
if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
stopIO();
- d_connectionDied = true;
+ d_connectionClosing = true;
return;
}
}
}
- if (d_state == State::waitingForQuery || d_state == State::idle) {
- readHTTPData();
+ if (active() && !d_connectionClosing && (d_state == State::waitingForQuery || d_state == State::idle)) {
+ do {
+ iostate = readHTTPData();
+ } while (active() && !d_connectionClosing && iostate == IOState::Done);
}
- if (!d_connectionDied) {
- auto shared = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+ if (!active()) {
+ stopIO();
+ return;
+ }
+ /*
+ So:
+ - if we have a pending write, we need to wait until the socket becomes writable
+ and then call handleWritableCallback
+ - if we have NeedWrite but no pending write, we need to wait until the socket
+ becomes writable but for handleReadableIOCallback
+ - if we have NeedRead, or nghttp2_session_want_read, wait until the socket
+ becomes readable and call handleReadableIOCallback
+ */
+ if (hasPendingWrite()) {
+ updateIO(IOState::NeedWrite, handleWritableIOCallback);
+ }
+ else if (iostate == IOState::NeedWrite) {
+ updateIO(IOState::NeedWrite, handleReadableIOCallback);
+ }
+ else if (!d_connectionClosing) {
if (nghttp2_session_want_read(d_session.get()) != 0) {
- d_ioState->add(IOState::NeedRead, &handleReadableIOCallback, shared, boost::none);
+ updateIO(IOState::NeedRead, handleReadableIOCallback);
}
- if (nghttp2_session_want_write(d_session.get()) != 0) {
- d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, shared, boost::none);
+ else {
+ if (isIdle()) {
+ watchForRemoteHostClosingConnection();
+ }
}
}
}
catch (const std::exception& e) {
- vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+ infolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
d_connectionDied = true;
stopIO();
}
}
-ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+void IncomingHTTP2Connection::writeToSocket(bool socketReady)
{
- auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API
- conn->d_out.insert(conn->d_out.end(), data, data + length);
-
- if (conn->d_connectionDied || conn->d_needFlush) {
- try {
- conn->d_needFlush = false;
- auto state = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
- if (state == IOState::Done) {
- conn->d_out.clear();
- conn->d_outPos = 0;
- if (!conn->isIdle()) {
- conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
+ try {
+ d_needFlush = false;
+ IOState newState = d_handler.tryWrite(d_out, d_outPos, d_out.size());
+
+ if (newState == IOState::Done) {
+ d_pendingWrite = false;
+ d_out.clear();
+ d_outPos = 0;
+ if (active() && !d_connectionClosing) {
+ if (!isIdle()) {
+ updateIO(IOState::NeedRead, handleReadableIOCallback);
}
else {
- conn->watchForRemoteHostClosingConnection();
+ watchForRemoteHostClosingConnection();
}
}
else {
- conn->updateIO(state, handleWritableIOCallback);
+ stopIO();
}
}
- catch (const std::exception& e) {
- vinfolog("Exception while trying to write (send) to incoming HTTP connection to %s: %s", conn->d_ci.remote.toStringWithPort(), e.what());
- conn->handleIOError();
+ else {
+ updateIO(newState, handleWritableIOCallback);
+ d_pendingWrite = true;
}
}
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to write (%s) to HTTP client connection to %s: %s", (socketReady ? "ready" : "send"), d_ci.remote.toStringWithPort(), e.what());
+ handleIOError();
+ }
+}
+
+ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+ if (conn->d_connectionDied) {
+ return static_cast<ssize_t>(length);
+ }
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API
+ conn->d_out.insert(conn->d_out.end(), data, data + length);
+
+ if (conn->d_connectionClosing || conn->d_needFlush) {
+ conn->writeToSocket(false);
+ }
return static_cast<ssize_t>(length);
}
sendResponse(response.d_idstate.d_streamID, context, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
handleResponseSent(response);
- return IOState::Done;
+ return hasPendingWrite() ? IOState::NeedWrite : IOState::Done;
}
void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response)
sendResponse(streamID, query, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
};
+ if (query.d_method == PendingQuery::Method::Unknown ||
+ query.d_method == PendingQuery::Method::Unsupported) {
+ handleImmediateResponse(400, "DoH query not allowed because of unsupported HTTP method");
+ return;
+ }
+
++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries;
if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) {
int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
{
auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
-#if 0
- switch (frame->hd.type) {
- case NGHTTP2_HEADERS:
- cerr<<"got headers"<<endl;
- if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
- cerr<<"All headers received"<<endl;
- }
- if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
- cerr<<"All headers received - query"<<endl;
- }
- break;
- case NGHTTP2_WINDOW_UPDATE:
- cerr<<"got window update"<<endl;
- break;
- case NGHTTP2_SETTINGS:
- cerr<<"got settings"<<endl;
- cerr<<frame->settings.niv<<endl;
- for (size_t idx = 0; idx < frame->settings.niv; idx++) {
- cerr<<"- "<<frame->settings.iv[idx].settings_id<<" "<<frame->settings.iv[idx].value<<endl;
- }
- break;
- case NGHTTP2_DATA:
- cerr<<"got data"<<endl;
- break;
- }
-#endif
-
- if (frame->hd.type == NGHTTP2_GOAWAY) {
- conn->stopIO();
- if (conn->isIdle()) {
- if (nghttp2_session_want_write(conn->d_session.get()) != 0) {
- conn->d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, conn, boost::none);
- }
- }
- }
-
/* is this the last frame for this stream? */
- else if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) {
+ if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) {
auto streamID = frame->hd.stream_id;
auto stream = conn->d_currentStreams.find(streamID);
if (stream != conn->d_currentStreams.end()) {
if (!insertPair.second) {
/* there is a stream ID collision, something is very wrong! */
vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
- conn->d_connectionDied = true;
+ conn->d_connectionClosing = true;
+ conn->d_needFlush = true;
nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
auto ret = nghttp2_session_send(conn->d_session.get());
if (ret != 0) {
query.d_method = PendingQuery::Method::Post;
}
else {
+ query.d_method = PendingQuery::Method::Unsupported;
vinfolog("Unsupported method value");
- return NGHTTP2_ERR_CALLBACK_FAILURE;
+ return 0;
}
}
auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
vinfolog("Error in HTTP/2 connection from %d: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len));
- conn->d_connectionDied = true;
+ conn->d_connectionClosing = true;
+ conn->d_needFlush = true;
nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
auto ret = nghttp2_session_send(conn->d_session.get());
if (ret != 0) {
return 0;
}
-void IncomingHTTP2Connection::readHTTPData()
+IOState IncomingHTTP2Connection::readHTTPData()
{
IOState newState = IOState::Done;
- IOStateGuard ioGuard(d_ioState);
- do {
- size_t got = 0;
- if (d_in.size() < 128) {
- d_in.resize(std::max(static_cast<size_t>(128U), d_in.capacity()));
- }
- try {
- newState = d_handler.tryRead(d_in, got, d_in.size(), true);
- d_in.resize(got);
-
- if (got > 0) {
- /* we got something */
- auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
- /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
- all data should be consumed before returning */
- if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
- throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
- }
-
- nghttp2_session_send(d_session.get());
+ size_t got = 0;
+ if (d_in.size() < s_initialReceiveBufferSize) {
+ d_in.resize(std::max(s_initialReceiveBufferSize, d_in.capacity()));
+ }
+ try {
+ newState = d_handler.tryRead(d_in, got, d_in.size(), true);
+ d_in.resize(got);
+
+ if (got > 0) {
+ /* we got something */
+ auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
+ /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
+ all data should be consumed before returning */
+ if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
+ throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
}
- if (newState == IOState::Done) {
- if (nghttp2_session_want_read(d_session.get()) != 0) {
- continue;
- }
- if (isIdle()) {
- watchForRemoteHostClosingConnection();
- ioGuard.release();
- break;
- }
- }
- else {
- if (newState == IOState::NeedWrite) {
- updateIO(IOState::NeedWrite, handleReadableIOCallback);
- }
- ioGuard.release();
- break;
- }
+ nghttp2_session_send(d_session.get());
}
- catch (const std::exception& e) {
- vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what());
- handleIOError();
- break;
- }
- } while (newState == IOState::Done || !isIdle());
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what());
+ handleIOError();
+ return IOState::Done;
+ }
+ return newState;
}
void IncomingHTTP2Connection::handleReadableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param)
void IncomingHTTP2Connection::handleWritableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param)
{
auto conn = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
- IOStateGuard ioGuard(conn->d_ioState);
-
- try {
- IOState newState = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
- if (newState == IOState::NeedRead) {
- conn->updateIO(IOState::NeedRead, handleWritableIOCallback);
- }
- else if (newState == IOState::Done) {
- conn->d_out.clear();
- conn->d_outPos = 0;
- if (!conn->isIdle()) {
- conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
- }
- else {
- conn->watchForRemoteHostClosingConnection();
- }
- }
- ioGuard.release();
- }
- catch (const std::exception& e) {
- vinfolog("Exception while trying to write (ready) to HTTP client connection to %s: %s", conn->d_ci.remote.toStringWithPort(), e.what());
- conn->handleIOError();
- }
+ conn->writeToSocket(true);
}
bool IncomingHTTP2Connection::isIdle() const
void IncomingHTTP2Connection::watchForRemoteHostClosingConnection()
{
- updateIO(IOState::NeedRead, handleReadableIOCallback);
+ if (d_connectionDied) {
+ return;
+ }
+
+ if (hasPendingWrite()) {
+ updateIO(IOState::NeedWrite, &handleWritableIOCallback);
+ }
+ else if (!d_connectionClosing) {
+ updateIO(IOState::NeedRead, handleReadableIOCallback);
+ }
}
void IncomingHTTP2Connection::handleIOError()
{
d_connectionDied = true;
+ d_out.clear();
+ d_outPos = 0;
nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR);
d_currentStreams.clear();
stopIO();
}
+
+bool IncomingHTTP2Connection::active() const
+{
+ return !d_connectionDied && d_ioState != nullptr;
+}
+
#endif /* HAVE_NGHTTP2 */
{
Unknown,
Get,
- Post
+ Post,
+ Unsupported
};
PacketBuffer d_buffer;
void handleIO() override;
void handleResponse(const struct timeval& now, TCPResponse&& response) override;
void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
+ bool active() const override;
private:
static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data);
static void handleReadableIOCallback(int descriptor, FDMultiplexer::funcparam_t& param);
static void handleWritableIOCallback(int descriptor, FDMultiplexer::funcparam_t& param);
+ static constexpr size_t s_initialReceiveBufferSize{256U};
+
IOState sendResponse(const struct timeval& now, TCPResponse&& response) override;
bool forwardViaUDPFirst() const override
{
bool sendResponse(StreamID streamID, PendingQuery& context, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true);
void handleIncomingQuery(PendingQuery&& query, StreamID streamID);
bool checkALPN();
- void readHTTPData();
+ IOState readHTTPData();
void handleConnectionReady();
+ bool hasPendingWrite() const;
+ void writeToSocket(bool socketReady);
boost::optional<struct timeval> getIdleClientReadTTD(struct timeval now) const;
std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)> d_session{nullptr, nghttp2_session_del};
PacketBuffer d_out;
PacketBuffer d_in;
size_t d_outPos{0};
+ /* this connection is done, the remote end has closed the connection
+ or something like that. We do not want to try to write to it. */
bool d_connectionDied{false};
+ /* we are done reading from this connection, but we might still want to
+ write to it to close it properly */
+ bool d_connectionClosing{false};
+ /* Whether we are still waiting for more data to be buffered
+ before writing to the socket (false) or not. */
bool d_needFlush{false};
+ /* Whether we have data that we want to write to the socket,
+ but the socket is full. */
+ bool d_pendingWrite{false};
};
class NGHTTP2Headers
*/
nghttp2_data_provider data_provider;
- /* we will not use this pointer */
data_provider.source.ptr = this;
data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t {
- auto conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+ auto conn = static_cast<DoHConnectionToBackend*>(user_data);
auto& request = conn->d_currentStreams.at(stream_id);
size_t toCopy = 0;
if (request.d_queryPos < request.d_query.d_buffer.size()) {
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnswriter.hh"
+#include "dnsdist.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsdist-nghttp2-in.hh"
+
+#ifdef HAVE_NGHTTP2
+#include <nghttp2/nghttp2.h>
+
+extern std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
+
+BOOST_AUTO_TEST_SUITE(test_dnsdistnghttp2_in_cc)
+
+struct ExpectedStep
+{
+public:
+ enum class ExpectedRequest
+ {
+ handshakeClient,
+ readFromClient,
+ writeToClient,
+ closeClient,
+ };
+
+ ExpectedStep(ExpectedRequest r, IOState n, size_t b = 0, std::function<void(int descriptor)> fn = nullptr) :
+ cb(fn), request(r), nextState(n), bytes(b)
+ {
+ }
+
+ std::function<void(int descriptor)> cb{nullptr};
+ ExpectedRequest request;
+ IOState nextState;
+ size_t bytes{0};
+};
+
+struct ExpectedData
+{
+ PacketBuffer d_proxyProtocolPayload;
+ std::vector<PacketBuffer> d_queries;
+ std::vector<PacketBuffer> d_responses;
+ std::vector<uint16_t> d_responseCodes;
+};
+
+class DOHConnection;
+
+static std::deque<ExpectedStep> s_steps;
+static std::map<uint64_t, ExpectedData> s_connectionContexts;
+static std::map<int, std::unique_ptr<DOHConnection>> s_connectionBuffers;
+static uint64_t s_connectionID{0};
+
+std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d);
+
+std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d)
+{
+ static const std::vector<std::string> requests = {"handshake with client", "read from client", "write to client", "close connection to client", "connect to the backend", "read from the backend", "write to the backend", "close connection to backend"};
+ os << requests.at(static_cast<size_t>(d));
+ return os;
+}
+
+class DOHConnection
+{
+public:
+ DOHConnection(uint64_t connectionID) :
+ d_session(std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)>(nullptr, nghttp2_session_del)), d_connectionID(connectionID)
+ {
+ const auto& context = s_connectionContexts.at(connectionID);
+ d_clientOutBuffer.insert(d_clientOutBuffer.begin(), context.d_proxyProtocolPayload.begin(), context.d_proxyProtocolPayload.end());
+
+ nghttp2_session_callbacks* cbs = nullptr;
+ nghttp2_session_callbacks_new(&cbs);
+ std::unique_ptr<nghttp2_session_callbacks, void (*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
+ cbs = nullptr;
+ nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback);
+ nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback);
+ nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
+ nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
+ nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
+ nghttp2_session* sess = nullptr;
+ nghttp2_session_client_new(&sess, callbacks.get(), this);
+ d_session = std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)>(sess, nghttp2_session_del);
+
+ nghttp2_settings_entry iv[] = {
+ /* rfc7540 section-8.2.2:
+ "Advertising a SETTINGS_MAX_CONCURRENT_STREAMS value of zero disables
+ server push by preventing the server from creating the necessary
+ streams."
+ */
+ {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 0},
+ {NGHTTP2_SETTINGS_ENABLE_PUSH, 0},
+ /* we might want to make the initial window size configurable, but 16M is a large enough default */
+ {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16 * 1024 * 1024}};
+ /* client 24 bytes magic string will be sent by nghttp2 library */
+ auto result = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv) / sizeof(*iv));
+ if (result != 0) {
+ throw std::runtime_error("Error submitting settings:" + std::string(nghttp2_strerror(result)));
+ }
+
+ const std::string host("unit-tests");
+ const std::string path("/dns-query");
+ for (const auto& query : context.d_queries) {
+ const auto querySize = std::to_string(query.size());
+ std::vector<nghttp2_nv> headers;
+ /* Pseudo-headers need to come first (rfc7540 8.1.2.1) */
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::METHOD_NAME, NGHTTP2Headers::HeaderConstantIndexes::METHOD_VALUE);
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::SCHEME_NAME, NGHTTP2Headers::HeaderConstantIndexes::SCHEME_VALUE);
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::AUTHORITY_NAME, host);
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::PATH_NAME, path);
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::ACCEPT_NAME, NGHTTP2Headers::HeaderConstantIndexes::ACCEPT_VALUE);
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_VALUE);
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::USER_AGENT_NAME, NGHTTP2Headers::HeaderConstantIndexes::USER_AGENT_VALUE);
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_LENGTH_NAME, querySize);
+
+ d_position = 0;
+ d_currentQuery = query;
+ nghttp2_data_provider data_provider;
+ data_provider.source.ptr = this;
+ data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t {
+ auto* conn = static_cast<DOHConnection*>(user_data);
+ auto& pos = conn->d_position;
+ const auto& currentQuery = conn->d_currentQuery;
+ size_t toCopy = 0;
+ if (pos < currentQuery.size()) {
+ size_t remaining = currentQuery.size() - pos;
+ toCopy = length > remaining ? remaining : length;
+ memcpy(buf, ¤tQuery.at(pos), toCopy);
+ pos += toCopy;
+ }
+
+ if (pos >= currentQuery.size()) {
+ *data_flags |= NGHTTP2_DATA_FLAG_EOF;
+ }
+ return toCopy;
+ };
+
+ auto newStreamId = nghttp2_submit_request(d_session.get(), nullptr, headers.data(), headers.size(), &data_provider, this);
+ if (newStreamId < 0) {
+ throw std::runtime_error("Error submitting HTTP request:" + std::string(nghttp2_strerror(newStreamId)));
+ }
+
+ result = nghttp2_session_send(d_session.get());
+ if (result != 0) {
+ throw std::runtime_error("Error in nghttp2_session_send:" + std::to_string(result));
+ }
+ }
+ }
+
+ std::map<int32_t, PacketBuffer> d_responses;
+ std::map<int32_t, uint16_t> d_responseCodes;
+ std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)> d_session;
+ PacketBuffer d_currentQuery;
+ PacketBuffer d_clientOutBuffer;
+ uint64_t d_connectionID{0};
+ size_t d_position{0};
+
+ size_t submitIncoming(const PacketBuffer& data, size_t pos, size_t toWrite)
+ {
+ ssize_t readlen = nghttp2_session_mem_recv(d_session.get(), &data.at(pos), toWrite);
+ if (readlen < 0) {
+ throw("Fatal error while submitting line " + std::to_string(__LINE__) + ": " + std::string(nghttp2_strerror(static_cast<int>(readlen))));
+ }
+
+ /* just in case, see if we have anything to send */
+ int rv = nghttp2_session_send(d_session.get());
+ if (rv != 0) {
+ throw("Fatal error while sending: " + std::string(nghttp2_strerror(rv)));
+ }
+
+ return readlen;
+ }
+
+private:
+ static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+ {
+ DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+ conn->d_clientOutBuffer.insert(conn->d_clientOutBuffer.end(), data, data + length);
+ return static_cast<ssize_t>(length);
+ }
+
+ static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+ {
+ DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+ if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) {
+ const auto& response = conn->d_responses.at(frame->hd.stream_id);
+ if (conn->d_responseCodes.at(frame->hd.stream_id) != 200U) {
+ return 0;
+ }
+
+ BOOST_REQUIRE_GT(response.size(), sizeof(dnsheader));
+ const auto* dh = reinterpret_cast<const dnsheader*>(response.data());
+ uint16_t id = ntohs(dh->id);
+
+ const auto& expected = s_connectionContexts.at(conn->d_connectionID).d_responses.at(id);
+ BOOST_REQUIRE_EQUAL(expected.size(), response.size());
+ for (size_t idx = 0; idx < response.size(); idx++) {
+ if (expected.at(idx) != response.at(idx)) {
+ cerr << "Mismatch at offset " << idx << ", expected " << std::to_string(response.at(idx)) << " got " << std::to_string(expected.at(idx)) << endl;
+ BOOST_CHECK(false);
+ }
+ }
+ }
+
+ return 0;
+ }
+
+ static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data)
+ {
+ DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+ auto& response = conn->d_responses[stream_id];
+ response.insert(response.end(), data, data + len);
+ return 0;
+ }
+
+ static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data)
+ {
+ DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+
+ const std::string status(":status");
+ if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
+ if (namelen == status.size() && memcmp(status.data(), name, status.size()) == 0) {
+ try {
+ uint16_t responseCode{0};
+ auto expected = s_connectionContexts.at(conn->d_connectionID).d_responseCodes.at((frame->hd.stream_id - 1) / 2);
+ pdns::checked_stoi_into(responseCode, std::string(reinterpret_cast<const char*>(value), valuelen));
+ conn->d_responseCodes[frame->hd.stream_id] = responseCode;
+ if (responseCode != expected) {
+ cerr << "Mismatch response code, expected " << std::to_string(expected) << " got " << std::to_string(responseCode) << endl;
+ BOOST_CHECK(false);
+ }
+ }
+ catch (const std::exception& e) {
+ infolog("Error parsing the status header for stream ID %d: %s", frame->hd.stream_id, e.what());
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ }
+ }
+ return 0;
+ }
+
+ static int on_stream_close_callback(nghttp2_session* session, int32_t stream_id, uint32_t error_code, void* user_data)
+ {
+ return 0;
+ }
+};
+
+class MockupTLSConnection : public TLSConnection
+{
+public:
+ MockupTLSConnection(int descriptor, [[maybe_unused]] bool client = false, [[maybe_unused]] bool needProxyProtocol = false) :
+ d_descriptor(descriptor)
+ {
+ auto connectionID = s_connectionID++;
+ auto conn = std::make_unique<DOHConnection>(connectionID);
+ s_connectionBuffers[d_descriptor] = std::move(conn);
+ }
+
+ ~MockupTLSConnection() {}
+
+ IOState tryHandshake() override
+ {
+ auto step = getStep();
+ BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::handshakeClient);
+
+ return step.nextState;
+ }
+
+ IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+ {
+ auto& conn = s_connectionBuffers.at(d_descriptor);
+ auto step = getStep();
+ BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::writeToClient);
+
+ if (step.bytes == 0) {
+ if (step.nextState == IOState::NeedWrite) {
+ return step.nextState;
+ }
+ throw std::runtime_error("Remote host closed the connection");
+ }
+
+ toWrite -= pos;
+ BOOST_REQUIRE_GE(buffer.size(), pos + toWrite);
+
+ if (step.bytes < toWrite) {
+ toWrite = step.bytes;
+ }
+
+ conn->submitIncoming(buffer, pos, toWrite);
+ pos += toWrite;
+
+ return step.nextState;
+ }
+
+ IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete = false) override
+ {
+ auto& conn = s_connectionBuffers.at(d_descriptor);
+ auto step = getStep();
+ BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::readFromClient);
+
+ if (step.bytes == 0) {
+ if (step.nextState == IOState::NeedRead) {
+ return step.nextState;
+ }
+ throw std::runtime_error("Remote host closed the connection");
+ }
+
+ auto& externalBuffer = conn->d_clientOutBuffer;
+ toRead -= pos;
+
+ if (step.bytes < toRead) {
+ toRead = step.bytes;
+ }
+ if (allowIncomplete) {
+ if (toRead > externalBuffer.size()) {
+ toRead = externalBuffer.size();
+ }
+ }
+ else {
+ BOOST_REQUIRE_GE(externalBuffer.size(), toRead);
+ }
+
+ BOOST_REQUIRE_GE(buffer.size(), toRead);
+
+ std::copy(externalBuffer.begin(), externalBuffer.begin() + toRead, buffer.begin() + pos);
+ pos += toRead;
+ externalBuffer.erase(externalBuffer.begin(), externalBuffer.begin() + toRead);
+
+ return step.nextState;
+ }
+
+ IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
+ {
+ throw std::runtime_error("Should not happen");
+ }
+
+ void close() override
+ {
+ auto step = getStep();
+ BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::closeClient);
+ }
+
+ bool hasBufferedData() const override
+ {
+ return false;
+ }
+
+ bool isUsable() const override
+ {
+ return true;
+ }
+
+ std::string getServerNameIndication() const override
+ {
+ return "";
+ }
+
+ std::vector<uint8_t> getNextProtocol() const override
+ {
+ return std::vector<uint8_t>{'h', '2'};
+ }
+
+ LibsslTLSVersion getTLSVersion() const override
+ {
+ return LibsslTLSVersion::TLS13;
+ }
+
+ bool hasSessionBeenResumed() const override
+ {
+ return false;
+ }
+
+ std::vector<std::unique_ptr<TLSSession>> getSessions() override
+ {
+ return {};
+ }
+
+ void setSession(std::unique_ptr<TLSSession>& session) override
+ {
+ }
+
+ std::vector<int> getAsyncFDs() override
+ {
+ return {};
+ }
+
+ /* unused in that context, don't bother */
+ void doHandshake() override
+ {
+ }
+
+ void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
+ {
+ }
+
+ size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0, 0}, bool allowIncomplete = false) override
+ {
+ return 0;
+ }
+
+ size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
+ {
+ return 0;
+ }
+
+private:
+ ExpectedStep getStep() const
+ {
+ BOOST_REQUIRE(!s_steps.empty());
+ auto step = s_steps.front();
+ s_steps.pop_front();
+
+ if (step.cb) {
+ step.cb(d_descriptor);
+ }
+
+ return step;
+ }
+
+ const int d_descriptor;
+};
+
+#include "test-dnsdistnghttp2_common.hh"
+
+struct TestFixture
+{
+ TestFixture()
+ {
+ s_steps.clear();
+ s_connectionContexts.clear();
+ s_connectionBuffers.clear();
+ s_connectionID = 0;
+ s_mplexer = std::make_unique<MockupFDMultiplexer>();
+ }
+ ~TestFixture()
+ {
+ s_steps.clear();
+ s_connectionContexts.clear();
+ s_connectionBuffers.clear();
+ s_connectionID = 0;
+ s_mplexer.reset();
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(test_IncomingConnection_SelfAnswered, TestFixture)
+{
+ auto local = getBackendAddress("1", 80);
+ ClientState localCS(local, true, false, false, "", {});
+ localCS.dohFrontend = std::make_shared<DOHFrontend>(std::make_shared<MockupTLSCtx>());
+ localCS.dohFrontend->d_urls.insert("/dns-query");
+
+ TCPClientThreadData threadData;
+ threadData.mplexer = std::make_unique<MockupFDMultiplexer>();
+
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+
+ size_t counter = 0;
+ DNSName name("powerdns.com.");
+ PacketBuffer query;
+ GenericDNSPacketWriter<PacketBuffer> pwQ(query, name, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+ pwQ.getHeader()->id = htons(counter);
+
+ PacketBuffer response;
+ GenericDNSPacketWriter<PacketBuffer> pwR(response, name, QType::A, QClass::IN, 0);
+ pwR.getHeader()->qr = 1;
+ pwR.getHeader()->rd = 1;
+ pwR.getHeader()->ra = 1;
+ pwR.getHeader()->id = htons(counter);
+ pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
+ pwR.xfr32BitInt(0x01020304);
+ pwR.commit();
+
+ {
+ /* dnsdist drops the query right away after receiving it, client closes the connection */
+ s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {403U}};
+ s_steps = {
+ /* opening */
+ { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+ /* settings server -> client */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+ /* settings + headers + data client -> server.. */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+ /* .. continued */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+ /* headers + data */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits<size_t>::max() },
+ /* wait for next query, but the client closes the connection */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+ /* server close */
+ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+ };
+
+ auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+ state->handleIO();
+ }
+
+ {
+ /* client closes the connection right in the middle of sending the query */
+ s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, { 403U }};
+ s_steps = {
+ /* opening */
+ { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+ /* settings server -> client */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+ /* client sends one byte */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 1 },
+ /* then closes the connection */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+ /* server close */
+ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+ };
+
+ /* mark the incoming FD as always ready */
+ dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+ auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+ state->handleIO();
+ while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+ threadData.mplexer->run(&now);
+ }
+ }
+
+ {
+ /* dnsdist sends a response right away, client closes the connection after getting the response */
+ s_processQuery = [response](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+ /* self answered */
+ dq.getMutableData() = response;
+ return ProcessQueryResult::SendAnswer;
+ };
+
+ s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}};
+
+ s_steps = {
+ /* opening */
+ { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+ /* settings server -> client */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+ /* settings + headers + data client -> server.. */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+ /* .. continued */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+ /* headers + data */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits<size_t>::max() },
+ /* wait for next query, but the client closes the connection */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+ /* server close */
+ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+ };
+
+ auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+ state->handleIO();
+ }
+
+ {
+ /* dnsdist sends a response right away, but the client closes the connection without even reading the response */
+ s_processQuery = [response](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+ /* self answered */
+ dq.getMutableData() = response;
+ return ProcessQueryResult::SendAnswer;
+ };
+
+ s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}};
+
+ s_steps = {
+ /* opening */
+ { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+ /* settings server -> client */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+ /* settings + headers + data client -> server.. */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+ /* .. continued */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+ /* we want to send the response but the client closes the connection */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 },
+ /* server close */
+ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+ };
+
+ /* mark the incoming FD as always ready */
+ dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+ auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+ state->handleIO();
+ while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+ threadData.mplexer->run(&now);
+ }
+ }
+
+ {
+ /* dnsdist sends a response right away, client closes the connection while getting the response */
+ s_processQuery = [response](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+ /* self answered */
+ dq.getMutableData() = response;
+ return ProcessQueryResult::SendAnswer;
+ };
+
+ s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}};
+
+ s_steps = {
+ /* opening */
+ { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+ /* settings server -> client */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+ /* settings + headers + data client -> server.. */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+ /* .. continued */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+ /* headers + data (partial write) */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::NeedWrite, 1 },
+ /* nothing to read after that */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 },
+ /* then the client closes the connection before we are done */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 },
+ /* server close */
+ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+ };
+
+ /* mark the incoming FD as always ready */
+ dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+ auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+ state->handleIO();
+ while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+ threadData.mplexer->run(&now);
+ }
+ }
+}
+
+BOOST_FIXTURE_TEST_CASE(test_IncomingConnection_BackendTimeout, TestFixture)
+{
+ auto local = getBackendAddress("1", 80);
+ ClientState localCS(local, true, false, false, "", {});
+ localCS.dohFrontend = std::make_shared<DOHFrontend>(std::make_shared<MockupTLSCtx>());
+ localCS.dohFrontend->d_urls.insert("/dns-query");
+
+ TCPClientThreadData threadData;
+ threadData.mplexer = std::make_unique<MockupFDMultiplexer>();
+
+ auto backend = std::make_shared<DownstreamState>(getBackendAddress("42", 53));
+
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+
+ size_t counter = 0;
+ DNSName name("powerdns.com.");
+ PacketBuffer query;
+ GenericDNSPacketWriter<PacketBuffer> pwQ(query, name, QType::A, QClass::IN, 0);
+ pwQ.getHeader()->rd = 1;
+ pwQ.getHeader()->id = htons(counter);
+
+ PacketBuffer response;
+ GenericDNSPacketWriter<PacketBuffer> pwR(response, name, QType::A, QClass::IN, 0);
+ pwR.getHeader()->qr = 1;
+ pwR.getHeader()->rd = 1;
+ pwR.getHeader()->ra = 1;
+ pwR.getHeader()->id = htons(counter);
+ pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
+ pwR.xfr32BitInt(0x01020304);
+ pwR.commit();
+
+ {
+ /* dnsdist forwards the query to the backend, which does not answer -> timeout */
+ s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+ selectedBackend = backend;
+ return ProcessQueryResult::PassToBackend;
+ };
+ s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {502U}};
+ s_steps = {
+ /* opening */
+ { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+ /* settings server -> client */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+ /* settings + headers + data client -> server.. */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+ /* .. continued */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+ /* trying to read a new request while processing the first one */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead },
+ /* headers + data */
+ { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits<size_t>::max(), [&threadData](int desc) {
+ /* set the incoming descriptor as ready */
+ dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(desc);
+ }
+ },
+ /* wait for next query, but the client closes the connection */
+ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+ /* server close */
+ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+ };
+
+ auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+ state->handleIO();
+ TCPResponse resp;
+ resp.d_idstate.d_streamID = 1;
+ state->notifyIOError(now, std::move(resp));
+ while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+ threadData.mplexer->run(&now);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
+#endif /* HAVE_NGHTTP2 */
bool d_client{false};
};
-class MockupTLSCtx : public TLSCtx
-{
-public:
- ~MockupTLSCtx()
- {
- }
-
- std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
- {
- return std::make_unique<MockupTLSConnection>(socket);
- }
-
- std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
- {
- return std::make_unique<MockupTLSConnection>(socket, true, d_needProxyProtocol);
- }
-
- void rotateTicketsKey(time_t now) override
- {
- }
-
- size_t getTicketsKeysCount() override
- {
- return 0;
- }
-
- std::string getName() const override
- {
- return "Mockup TLS";
- }
-
- bool d_needProxyProtocol{false};
-};
-
-class MockupFDMultiplexer : public FDMultiplexer
-{
-public:
- MockupFDMultiplexer()
- {
- }
-
- ~MockupFDMultiplexer()
- {
- }
-
- int run(struct timeval* tv, int timeout = 500) override
- {
- int ret = 0;
-
- gettimeofday(tv, nullptr); // MANDATORY
-
- /* 'ready' might be altered by a callback while we are iterating */
- const auto readyFDs = ready;
- for (const auto fd : readyFDs) {
- {
- const auto& it = d_readCallbacks.find(fd);
-
- if (it != d_readCallbacks.end()) {
- it->d_callback(it->d_fd, it->d_parameter);
- }
- }
-
- {
- const auto& it = d_writeCallbacks.find(fd);
-
- if (it != d_writeCallbacks.end()) {
- it->d_callback(it->d_fd, it->d_parameter);
- }
- }
- }
-
- return ret;
- }
-
- void getAvailableFDs(std::vector<int>& fds, int timeout) override
- {
- }
-
- void addFD(int fd, FDMultiplexer::EventKind kind) override
- {
- }
-
- void removeFD(int fd, FDMultiplexer::EventKind) override
- {
- }
-
- string getName() const override
- {
- return "mockup";
- }
-
- void setReady(int fd)
- {
- ready.insert(fd);
- }
-
- void setNotReady(int fd)
- {
- ready.erase(fd);
- }
-
-private:
- std::set<int> ready;
-};
+#include "test-dnsdistnghttp2_common.hh"
class MockupQuerySender : public TCPQuerySender
{
bool d_error{false};
};
-static bool isIPv6Supported()
-{
- try {
- ComboAddress addr("[2001:db8:53::1]:53");
- auto socket = std::make_unique<Socket>(addr.sin4.sin_family, SOCK_STREAM, 0);
- socket->setNonBlocking();
- int res = SConnectWithTimeout(socket->getHandle(), addr, timeval{0, 0});
- if (res == 0 || res == EINPROGRESS) {
- return true;
- }
- return false;
- }
- catch (const std::exception& e) {
- return false;
- }
-}
-
-static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t port)
-{
- static const bool useV6 = isIPv6Supported();
-
- if (useV6) {
- return ComboAddress("2001:db8:53::" + lastDigit, port);
- }
-
- return ComboAddress("192.0.2." + lastDigit, port);
-}
-
-static std::unique_ptr<FDMultiplexer> s_mplexer;
-
struct TestFixture
{
TestFixture()
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+class MockupTLSCtx : public TLSCtx
+{
+public:
+ ~MockupTLSCtx()
+ {
+ }
+
+ std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
+ {
+ return std::make_unique<MockupTLSConnection>(socket);
+ }
+
+ std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
+ {
+ return std::make_unique<MockupTLSConnection>(socket, true, d_needProxyProtocol);
+ }
+
+ void rotateTicketsKey(time_t now) override
+ {
+ }
+
+ size_t getTicketsKeysCount() override
+ {
+ return 0;
+ }
+
+ std::string getName() const override
+ {
+ return "Mockup TLS";
+ }
+
+ bool d_needProxyProtocol{false};
+};
+
+class MockupFDMultiplexer : public FDMultiplexer
+{
+public:
+ MockupFDMultiplexer()
+ {
+ }
+
+ ~MockupFDMultiplexer()
+ {
+ }
+
+ int run(struct timeval* tv, int timeout = 500) override
+ {
+ int ret = 0;
+
+ gettimeofday(tv, nullptr); // MANDATORY
+
+ /* 'ready' might be altered by a callback while we are iterating */
+ const auto readyFDs = ready;
+ for (const auto fd : readyFDs) {
+ {
+ const auto& it = d_readCallbacks.find(fd);
+
+ if (it != d_readCallbacks.end()) {
+ it->d_callback(it->d_fd, it->d_parameter);
+ }
+ }
+
+ {
+ const auto& it = d_writeCallbacks.find(fd);
+
+ if (it != d_writeCallbacks.end()) {
+ it->d_callback(it->d_fd, it->d_parameter);
+ }
+ }
+ }
+
+ return ret;
+ }
+
+ void getAvailableFDs(std::vector<int>& fds, int timeout) override
+ {
+ }
+
+ void addFD(int fd, FDMultiplexer::EventKind kind) override
+ {
+ }
+
+ void removeFD(int fd, FDMultiplexer::EventKind) override
+ {
+ }
+
+ string getName() const override
+ {
+ return "mockup";
+ }
+
+ void setReady(int fd)
+ {
+ ready.insert(fd);
+ }
+
+ void setNotReady(int fd)
+ {
+ ready.erase(fd);
+ }
+
+private:
+ std::set<int> ready;
+};
+
+static bool isIPv6Supported()
+{
+ try {
+ ComboAddress addr("[2001:db8:53::1]:53");
+ auto socket = std::make_unique<Socket>(addr.sin4.sin_family, SOCK_STREAM, 0);
+ socket->setNonBlocking();
+ int res = SConnectWithTimeout(socket->getHandle(), addr, timeval{0, 0});
+ if (res == 0 || res == EINPROGRESS) {
+ return true;
+ }
+ return false;
+ }
+ catch (const std::exception& e) {
+ return false;
+ }
+}
+
+static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t port)
+{
+ static const bool useV6 = isIPv6Supported();
+
+ if (useV6) {
+ return ComboAddress("2001:db8:53::" + lastDigit, port);
+ }
+
+ return ComboAddress("192.0.2." + lastDigit, port);
+}
+
+static std::unique_ptr<FDMultiplexer> s_mplexer;
{
}
-static std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
+std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
{
bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
{
- return false;
+ return true;
}
namespace dnsdist {
return conn
@classmethod
- def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+ def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, useProxyProtocol=False, conn=None):
url = cls.getDOHGetURL(baseurl, query, rawQuery)
- conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
- response_headers = BytesIO()
- #conn.setopt(pycurl.VERBOSE, True)
- conn.setopt(pycurl.URL, url)
- conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
- # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
- conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
+
+ if not conn:
+ print('creating a new connection')
+ conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+ # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+ conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
+
if useHTTPS:
+ print("disabling verify")
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
if caFile:
conn.setopt(pycurl.CAINFO, caFile)
+ if useProxyProtocol:
+ print('enabling PP')
+ # 274 is CURLOPT_HAPROXYPROTOCOL
+ conn.setopt(274, 1)
+
+ response_headers = BytesIO()
+ #conn.setopt(pycurl.VERBOSE, True)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+
conn.setopt(pycurl.HTTPHEADER, customHeaders)
conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
receivedQuery = None
message = None
cls._response_headers = ''
+ print('performing')
data = conn.perform_rb()
cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
if cls._rcode == 200 and not rawResponse:
cls._response_headers = response_headers.getvalue()
return (receivedQuery, message)
- def sendDOHQueryWrapper(self, query, response, useQueue=True):
- return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+ def sendDOHQueryWrapper(self, query, response, useQueue=True, useProxyProtocol=False):
+ return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, useProxyProtocol=useProxyProtocol)
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
#!/usr/bin/env python
+import base64
import dns
import os
import time
addAction(HTTPPathRegexRule("^/PowerDNS-[0-9]"), SpoofAction("6.7.8.9"))
addAction("http-status-action.doh.tests.powerdns.com.", HTTPStatusAction(200, "Plaintext answer", "text/plain"))
addAction("http-status-action-redirect.doh.tests.powerdns.com.", HTTPStatusAction(307, "https://doh.powerdns.org"))
+ addAction("no-backend.doh.tests.powerdns.com.", PoolAction('this-pool-has-no-backend'))
function dohHandler(dq)
if dq:getHTTPScheme() == 'https' and dq:getHTTPHost() == '%s:%d' and dq:getHTTPPath() == '/' and dq:getHTTPQueryString() == '' then
(_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, caFile=self._caCert, query=query, response=None, useQueue=False)
self.assertEqual(receivedResponse, expectedResponse)
+ def testDOHWithoutQuery(self):
+ """
+ DOH: Empty GET query
+ """
+ name = 'empty-get.doh.tests.powerdns.com.'
+ url = self._dohBaseURL
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ conn.setopt(pycurl.CAINFO, self._caCert)
+ data = conn.perform_rb()
+ rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ self.assertEqual(rcode, 400)
+
+ def testDOHShortPath(self):
+ """
+ DOH: Short path in GET query
+ """
+ name = 'short-path-get.doh.tests.powerdns.com.'
+ url = self._dohBaseURL + '/AA'
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ conn.setopt(pycurl.CAINFO, self._caCert)
+ data = conn.perform_rb()
+ rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ self.assertEqual(rcode, 404)
+
+ def testDOHQueryNoParameter(self):
+ """
+ DOH: No parameter GET query
+ """
+ name = 'no-parameter-get.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ wire = query.to_wire()
+ b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+ url = self._dohBaseURL + '?not-dns=' + b64
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ conn.setopt(pycurl.CAINFO, self._caCert)
+ data = conn.perform_rb()
+ rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ self.assertEqual(rcode, 400)
+
+ def testDOHQueryInvalidBase64(self):
+ """
+ DOH: Invalid Base64 GET query
+ """
+ name = 'invalid-b64-get.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ wire = query.to_wire()
+ url = self._dohBaseURL + '?dns=' + '_-~~~~-_'
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ conn.setopt(pycurl.CAINFO, self._caCert)
+ data = conn.perform_rb()
+ rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ self.assertEqual(rcode, 400)
+
+ def testDOHInvalidDNSHeaders(self):
+ """
+ DOH: Invalid DNS headers
+ """
+ name = 'invalid-dns-headers.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.flags |= dns.flags.QR
+ wire = query.to_wire()
+ b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+ url = self._dohBaseURL + '?dns=' + b64
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ conn.setopt(pycurl.CAINFO, self._caCert)
+ data = conn.perform_rb()
+ rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ self.assertEqual(rcode, 400)
+
+ def testDOHQueryInvalidMethod(self):
+ """
+ DOH: Invalid method
+ """
+ if self._dohLibrary == 'h2o':
+ raise unittest.SkipTest('h2o does not check the HTTP method')
+ name = 'invalid-method.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ wire = query.to_wire()
+ b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+ url = self._dohBaseURL + '?dns=' + b64
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ conn.setopt(pycurl.CAINFO, self._caCert)
+ conn.setopt(pycurl.CUSTOMREQUEST, 'PATCH')
+ data = conn.perform_rb()
+ rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ self.assertEqual(rcode, 400)
+
+ def testDOHQueryInvalidALPN(self):
+ """
+ DOH: Invalid ALPN
+ """
+ alpn = ['bogus-alpn']
+ conn = self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=alpn)
+ try:
+ conn.send('AAAA')
+ response = conn.recv(65535)
+ self.assertFalse(response)
+ except:
+ pass
+
def testDOHInvalid(self):
"""
- DOH: Invalid query
+ DOH: Invalid DNS query
"""
name = 'invalid.doh.tests.powerdns.com.'
invalidQuery = dns.message.make_query(name, 'A', 'IN', use_edns=False)
self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
self.assertEqual(response, receivedResponse)
- def testDOHWithoutQuery(self):
+ def testDOHInvalidHeaderName(self):
"""
- DOH: Empty GET query
+ DOH: Invalid HTTP header name query
"""
- name = 'empty-get.doh.tests.powerdns.com.'
- url = self._dohBaseURL
- conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+ name = 'invalid-header-name.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+ expectedQuery.id = 0
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+ # this header is invalid, see rfc9113 section 8.2.1. Field Validity
+ customHeaders = ['{}: test']
+ try:
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, customHeaders=customHeaders)
+ self.assertFalse(receivedQuery)
+ self.assertFalse(receivedResponse)
+ except pycurl.error:
+ pass
+
+ def testDOHNoBackend(self):
+ """
+ DOH: No backend
+ """
+ if self._dohLibrary == 'h2o':
+ raise unittest.SkipTest('h2o does not check the HTTP method')
+ name = 'no-backend.doh.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ wire = query.to_wire()
+ b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+ url = self._dohBaseURL + '?dns=' + b64
+ conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.CAINFO, self._caCert)
data = conn.perform_rb()
rcode = conn.getinfo(pycurl.RESPONSE_CODE)
- self.assertEqual(rcode, 400)
+ self.assertEqual(rcode, 403)
def testDOHEmptyPOST(self):
"""