From: Remi Gacogne Date: Wed, 12 Jul 2023 15:46:57 +0000 (+0200) Subject: dnsdist: Add support for incoming proxy protocol outside the TLS layer X-Git-Tag: rec-5.0.0-alpha1~19^2~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9f433af44e2bf4e2f1edb00affc76a362c38eee1;p=thirdparty%2Fpdns.git dnsdist: Add support for incoming proxy protocol outside the TLS layer --- diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index c09ff65aea..d5c27f4de3 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -2409,6 +2409,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) getOptionalValue(vars, "serverTokens", frontend->d_serverTokens); getOptionalValue(vars, "provider", frontend->d_tlsContext.d_provider); boost::algorithm::to_lower(frontend->d_tlsContext.d_provider); + getOptionalValue(vars, "proxyProtocolOutsideTLS", frontend->d_tlsContext.d_proxyProtocolOutsideTLS); LuaAssociativeTable customResponseHeaders; if (getOptionalValue(vars, "customResponseHeaders", customResponseHeaders) > 0) { @@ -2647,6 +2648,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) getOptionalValue(vars, "provider", frontend->d_provider); boost::algorithm::to_lower(frontend->d_provider); + getOptionalValue(vars, "proxyProtocolOutsideTLS", frontend->d_proxyProtocolOutsideTLS); LuaArray addresses; if (getOptionalValue(vars, "additionalAddresses", addresses) > 0) { diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 08ecf1d4fc..2f78c22dad 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -299,7 +299,7 @@ void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_p /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */ IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response) { - d_state = IncomingTCPConnectionState::State::sendingResponse; + d_state = State::sendingResponse; uint16_t responseSize = static_cast(response.d_buffer.size()); const uint8_t sizeBytes[] = { static_cast(responseSize / 256), static_cast(responseSize % 256) }; @@ -382,18 +382,18 @@ void IncomingTCPConnectionState::queueResponse(std::shared_ptrd_state == IncomingTCPConnectionState::State::idle || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery) { + if (state->d_state == State::idle || + state->d_state == State::waitingForQuery) { auto iostate = sendQueuedResponses(state, now); if (iostate == IOState::Done && state->active()) { if (state->canAcceptNewQueries(now)) { state->resetForNewQuery(); - state->d_state = IncomingTCPConnectionState::State::waitingForQuery; + state->d_state = State::waitingForQuery; iostate = IOState::NeedRead; } else { - state->d_state = IncomingTCPConnectionState::State::idle; + state->d_state = State::idle; } } @@ -649,7 +649,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true); if (dnsCryptResponse) { TCPResponse response; - d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; ++d_currentQueriesCount; queueResponse(state, now, std::move(response)); return QueryProcessingResult::SelfAnswered; @@ -668,7 +668,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha dh->qr = true; response.d_idstate.selfGenerated = true; response.d_buffer = std::move(query); - d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; ++d_currentQueriesCount; queueResponse(state, now, std::move(response)); return QueryProcessingResult::Empty; @@ -749,7 +749,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha response.d_idstate.cs = d_ci.cs; response.d_buffer = std::move(query); - d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; ++d_currentQueriesCount; queueResponse(state, now, std::move(response)); return QueryProcessingResult::SelfAnswered; @@ -849,7 +849,7 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand { do { DEBUGLOG("reading proxy protocol header"); - auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed); + auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed, false, isProxyPayloadOutsideTLS()); if (iostate == IOState::Done) { d_buffer.resize(d_currentPos); ssize_t remaining = isProxyHeaderComplete(d_buffer); @@ -887,6 +887,30 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand return ProxyProtocolResult::Reading; } +IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now) +{ + DEBUGLOG("doing handshake"); + auto iostate = d_handler.tryHandshake(); + if (iostate == IOState::Done) { + DEBUGLOG("handshake done"); + handleHandshakeDone(now); + + if (!isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) { + d_state = State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; + } + else { + d_state = State::readingQuerySize; + } + } + else { + d_lastIOBlocked = true; + } + + return iostate; +} + void IncomingTCPConnectionState::handleIO() { // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read @@ -909,34 +933,34 @@ void IncomingTCPConnectionState::handleIO() d_lastIOBlocked = false; try { - if (d_state == IncomingTCPConnectionState::State::doingHandshake) { - DEBUGLOG("doing handshake"); - iostate = d_handler.tryHandshake(); - if (iostate == IOState::Done) { - DEBUGLOG("handshake done"); - handleHandshakeDone(now); - - if (expectProxyProtocolFrom(d_ci.remote)) { - d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader; - d_buffer.resize(s_proxyProtocolMinimumHeaderSize); - d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; - } - else { - d_state = IncomingTCPConnectionState::State::readingQuerySize; - } + if (d_state == State::starting) { + if (isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) { + d_state = State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; } else { - d_lastIOBlocked = true; + d_state = State::doingHandshake; } } - if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) { + if (d_state == State::doingHandshake) { + iostate = handleHandshake(now); + } + + if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) { auto status = handleProxyProtocolPayload(); if (status == ProxyProtocolResult::Done) { - d_state = IncomingTCPConnectionState::State::readingQuerySize; - d_buffer.resize(sizeof(uint16_t)); - d_currentPos = 0; - d_proxyProtocolNeed = 0; + if (isProxyPayloadOutsideTLS()) { + d_state = State::doingHandshake; + iostate = handleHandshake(now); + } + else { + d_state = State::readingQuerySize; + d_buffer.resize(sizeof(uint16_t)); + d_currentPos = 0; + d_proxyProtocolNeed = 0; + } } else if (status == ProxyProtocolResult::Error) { iostate = IOState::Done; @@ -946,19 +970,19 @@ void IncomingTCPConnectionState::handleIO() } } - if (!d_lastIOBlocked && (d_state == IncomingTCPConnectionState::State::waitingForQuery || - d_state == IncomingTCPConnectionState::State::readingQuerySize)) { + if (!d_lastIOBlocked && (d_state == State::waitingForQuery || + d_state == State::readingQuerySize)) { DEBUGLOG("reading query size"); d_buffer.resize(sizeof(uint16_t)); iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t)); if (d_currentPos > 0) { /* if we got at least one byte, we can't go around sending responses */ - d_state = IncomingTCPConnectionState::State::readingQuerySize; + d_state = State::readingQuerySize; } if (iostate == IOState::Done) { DEBUGLOG("query size received"); - d_state = IncomingTCPConnectionState::State::readingQuery; + d_state = State::readingQuery; d_querySizeReadTime = now; if (d_queriesCount == 0) { d_firstQuerySizeReadTime = now; @@ -980,14 +1004,14 @@ void IncomingTCPConnectionState::handleIO() } } - if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingQuery) { + if (!d_lastIOBlocked && d_state == State::readingQuery) { DEBUGLOG("reading query"); iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize); if (iostate == IOState::Done) { DEBUGLOG("query received"); d_buffer.resize(d_querySize); - d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt); switch (processingResult) { case QueryProcessingResult::TooSmall: @@ -1005,7 +1029,7 @@ void IncomingTCPConnectionState::handleIO() /* the state might have been updated in the meantime, we don't want to override it in that case */ - if (active() && d_state != IncomingTCPConnectionState::State::idle) { + if (active() && d_state != State::idle) { if (d_ioState->isWaitingForRead()) { iostate = IOState::NeedRead; } @@ -1022,13 +1046,13 @@ void IncomingTCPConnectionState::handleIO() } } - if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::sendingResponse) { + if (!d_lastIOBlocked && d_state == State::sendingResponse) { DEBUGLOG("sending response"); iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); handleResponseSent(d_currentResponse); - d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; } else { d_lastIOBlocked = true; @@ -1038,8 +1062,8 @@ void IncomingTCPConnectionState::handleIO() if (active() && !d_lastIOBlocked && iostate == IOState::Done && - (d_state == IncomingTCPConnectionState::State::idle || - d_state == IncomingTCPConnectionState::State::waitingForQuery)) + (d_state == State::idle || + d_state == State::waitingForQuery)) { // try sending queued responses DEBUGLOG("send responses, if any"); @@ -1054,19 +1078,19 @@ void IncomingTCPConnectionState::handleIO() iostate = IOState::NeedRead; } else { - d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; iostate = IOState::Done; } } } - if (d_state != IncomingTCPConnectionState::State::idle && - d_state != IncomingTCPConnectionState::State::doingHandshake && - d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader && - d_state != IncomingTCPConnectionState::State::waitingForQuery && - d_state != IncomingTCPConnectionState::State::readingQuerySize && - d_state != IncomingTCPConnectionState::State::readingQuery && - d_state != IncomingTCPConnectionState::State::sendingResponse) { + if (d_state != State::idle && + d_state != State::doingHandshake && + d_state != State::readingProxyProtocolHeader && + d_state != State::waitingForQuery && + d_state != State::readingQuerySize && + d_state != State::readingQuery && + d_state != State::sendingResponse) { vinfolog("Unexpected state %d in handleIOCallback", static_cast(d_state)); } } @@ -1075,18 +1099,18 @@ void IncomingTCPConnectionState::handleIO() but it might also be a real IO error or something else. Let's just drop the connection */ - if (d_state == IncomingTCPConnectionState::State::idle || - d_state == IncomingTCPConnectionState::State::waitingForQuery) { + if (d_state == State::idle || + d_state == State::waitingForQuery) { /* no need to increase any counters in that case, the client is simply done with us */ } - else if (d_state == IncomingTCPConnectionState::State::doingHandshake || - d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader || - d_state == IncomingTCPConnectionState::State::waitingForQuery || - d_state == IncomingTCPConnectionState::State::readingQuerySize || - d_state == IncomingTCPConnectionState::State::readingQuery) { + else if (d_state == State::doingHandshake || + d_state != State::readingProxyProtocolHeader || + d_state == State::waitingForQuery || + d_state == State::readingQuerySize || + d_state == State::readingQuery) { ++d_ci.cs->tcpDiedReadingQuery; } - else if (d_state == IncomingTCPConnectionState::State::sendingResponse) { + else if (d_state == State::sendingResponse) { /* unlikely to happen here, the exception should be handled in sendResponse() */ ++d_ci.cs->tcpDiedSendingResponse; } @@ -1180,7 +1204,7 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptrd_state = IncomingTCPConnectionState::State::idle; + state->d_state = State::idle; state->d_ioState->update(IOState::Done, handleIOCallback, state); } } diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index a9ecef0170..694dab3a6c 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -529,6 +529,17 @@ struct ClientState return tlsFrontend != nullptr || (dohFrontend != nullptr && dohFrontend->isHTTPS()); } + const TLSFrontend& getTLSFrontend() const + { + if (tlsFrontend != nullptr) { + return *tlsFrontend; + } + if (dohFrontend) { + return dohFrontend->d_tlsContext; + } + throw std::runtime_error("Trying to get a TLS frontend from a non-TLS ClientState"); + } + dnsdist::Protocol getProtocol() const { if (dnscryptCtx) { diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc index 7945a2544f..b22da0c397 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc @@ -290,6 +290,32 @@ bool IncomingHTTP2Connection::hasPendingWrite() const return d_pendingWrite; } +IOState IncomingHTTP2Connection::handleHandshake(const struct timeval& now) +{ + auto iostate = d_handler.tryHandshake(); + if (iostate == IOState::Done) { + handleHandshakeDone(now); + if (d_handler.isTLS()) { + if (!checkALPN()) { + d_connectionDied = true; + stopIO(); + return iostate; + } + } + + if (!isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) { + d_state = State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; + } + else { + d_state = State::waitingForQuery; + handleConnectionReady(); + } + } + return iostate; +} + void IncomingHTTP2Connection::handleIO() { IOState iostate = IOState::Done; @@ -306,39 +332,42 @@ void IncomingHTTP2Connection::handleIO() return; } + if (d_state == State::starting) { + if (isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) { + d_state = State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; + } + else { + d_state = State::doingHandshake; + } + } + if (d_state == State::doingHandshake) { - iostate = d_handler.tryHandshake(); - if (iostate == IOState::Done) { - handleHandshakeDone(now); - if (d_handler.isTLS()) { - if (!checkALPN()) { - d_connectionDied = true; - stopIO(); + iostate = handleHandshake(now); + if (d_connectionDied) { + return; + } + } + + if (d_state == State::readingProxyProtocolHeader) { + auto status = handleProxyProtocolPayload(); + if (status == ProxyProtocolResult::Done) { + if (isProxyPayloadOutsideTLS()) { + d_state = State::doingHandshake; + iostate = handleHandshake(now); + if (d_connectionDied) { return; } } - - if (expectProxyProtocolFrom(d_ci.remote)) { - d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader; - d_buffer.resize(s_proxyProtocolMinimumHeaderSize); - d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; - } else { + d_currentPos = 0; + d_proxyProtocolNeed = 0; + d_buffer.clear(); d_state = State::waitingForQuery; handleConnectionReady(); } } - } - - if (d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) { - auto status = handleProxyProtocolPayload(); - if (status == ProxyProtocolResult::Done) { - d_currentPos = 0; - d_proxyProtocolNeed = 0; - d_buffer.clear(); - d_state = State::waitingForQuery; - handleConnectionReady(); - } else if (status == ProxyProtocolResult::Error) { d_connectionDied = true; stopIO(); diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.hh b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh index 32af8d3087..a648a5027b 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.hh +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh @@ -95,6 +95,7 @@ private: bool checkALPN(); IOState readHTTPData(); void handleConnectionReady(); + IOState handleHandshake(const struct timeval& now) override; bool hasPendingWrite() const; void writeToSocket(bool socketReady); boost::optional getIdleClientReadTTD(struct timeval now) const; diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 4318892659..f1c49e93e4 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -138,6 +138,7 @@ public: virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response); void handleResponseSent(TCPResponse& currentResponse); + virtual IOState handleHandshake(const struct timeval& now); void handleHandshakeDone(const struct timeval& now); ProxyProtocolResult handleProxyProtocolPayload(); void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response); @@ -150,6 +151,14 @@ public: { return d_ioState != nullptr; } + bool isProxyPayloadOutsideTLS() const + { + if (!d_ci.cs->hasTLS()) { + return false; + } + return d_ci.cs->getTLSFrontend().d_proxyProtocolOutsideTLS; + } + virtual bool forwardViaUDPFirst() const { return false; @@ -174,7 +183,7 @@ public: dnsdist::Protocol getProtocol() const; - enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; + enum class State : uint8_t { starting, doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; TCPResponse d_currentResponse; std::map, std::deque>> d_ownedConnectionsToBackend; @@ -199,7 +208,7 @@ public: size_t d_currentQueriesCount{0}; std::thread::id d_creatorThreadID; uint16_t d_querySize{0}; - State d_state{State::doingHandshake}; + State d_state{State::starting}; bool d_isXFR{false}; bool d_proxyProtocolPayloadHasTLV{false}; bool d_lastIOBlocked{false}; diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index 5e1d23e737..3cf674ca16 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -226,6 +226,8 @@ public: ComboAddress d_addr; std::string d_provider; ALPN d_alpn{ALPN::Unset}; + /* whether the proxy protocol is inside or outside the TLS layer */ + bool d_proxyProtocolOutsideTLS{false}; protected: std::shared_ptr d_ctx{nullptr}; }; @@ -365,13 +367,13 @@ public: return Done when toRead bytes have been read, needRead or needWrite if the IO operation would block. */ - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false) { if (buffer.size() < toRead || pos >= toRead) { throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos)); } - if (d_conn) { + if (!bypassFilters && d_conn) { return d_conn->tryRead(buffer, pos, toRead, allowIncomplete); }