]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add support for incoming proxy protocol outside the TLS layer
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 12 Jul 2023 15:46:57 +0000 (17:46 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 08:22:05 +0000 (10:22 +0200)
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-nghttp2-in.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/tcpiohandler.hh

index c09ff65aead40b6c0e999f715992dc421dc8fb4a..d5c27f4de39b7e889647f8f4837aae13af920b95 100644 (file)
@@ -2409,6 +2409,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       getOptionalValue<std::string>(vars, "serverTokens", frontend->d_serverTokens);
       getOptionalValue<std::string>(vars, "provider", frontend->d_tlsContext.d_provider);
       boost::algorithm::to_lower(frontend->d_tlsContext.d_provider);
+      getOptionalValue<bool>(vars, "proxyProtocolOutsideTLS", frontend->d_tlsContext.d_proxyProtocolOutsideTLS);
 
       LuaAssociativeTable<std::string> customResponseHeaders;
       if (getOptionalValue<decltype(customResponseHeaders)>(vars, "customResponseHeaders", customResponseHeaders) > 0) {
@@ -2647,6 +2648,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
       getOptionalValue<std::string>(vars, "provider", frontend->d_provider);
       boost::algorithm::to_lower(frontend->d_provider);
+      getOptionalValue<bool>(vars, "proxyProtocolOutsideTLS", frontend->d_proxyProtocolOutsideTLS);
 
       LuaArray<std::string> addresses;
       if (getOptionalValue<decltype(addresses)>(vars, "additionalAddresses", addresses) > 0) {
index 08ecf1d4fcaa6bba6bbfe636160ec51b8730b9c9..2f78c22dad7a81d8d20fec7a0bbe8c3825f356d6 100644 (file)
@@ -299,7 +299,7 @@ void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_p
 /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
 IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response)
 {
-  d_state = IncomingTCPConnectionState::State::sendingResponse;
+  d_state = State::sendingResponse;
 
   uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
   const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
@@ -382,18 +382,18 @@ void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnec
   // idle, and thus not trying to send the response right away would make our ref count go to 0.
   // Even if we are waiting for a query, we will not wake up before the new query arrives or a
   // timeout occurs
-  if (state->d_state == IncomingTCPConnectionState::State::idle ||
-      state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
+  if (state->d_state == State::idle ||
+      state->d_state == State::waitingForQuery) {
     auto iostate = sendQueuedResponses(state, now);
 
     if (iostate == IOState::Done && state->active()) {
       if (state->canAcceptNewQueries(now)) {
         state->resetForNewQuery();
-        state->d_state = IncomingTCPConnectionState::State::waitingForQuery;
+        state->d_state = State::waitingForQuery;
         iostate = IOState::NeedRead;
       }
       else {
-        state->d_state = IncomingTCPConnectionState::State::idle;
+        state->d_state = State::idle;
       }
     }
 
@@ -649,7 +649,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
   auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
   if (dnsCryptResponse) {
     TCPResponse response;
-    d_state = IncomingTCPConnectionState::State::idle;
+    d_state = State::idle;
     ++d_currentQueriesCount;
     queueResponse(state, now, std::move(response));
     return QueryProcessingResult::SelfAnswered;
@@ -668,7 +668,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
       dh->qr = true;
       response.d_idstate.selfGenerated = true;
       response.d_buffer = std::move(query);
-      d_state = IncomingTCPConnectionState::State::idle;
+      d_state = State::idle;
       ++d_currentQueriesCount;
       queueResponse(state, now, std::move(response));
       return QueryProcessingResult::Empty;
@@ -749,7 +749,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
     response.d_idstate.cs = d_ci.cs;
     response.d_buffer = std::move(query);
 
-    d_state = IncomingTCPConnectionState::State::idle;
+    d_state = State::idle;
     ++d_currentQueriesCount;
     queueResponse(state, now, std::move(response));
     return QueryProcessingResult::SelfAnswered;
@@ -849,7 +849,7 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand
 {
   do {
     DEBUGLOG("reading proxy protocol header");
-    auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed);
+    auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed, false, isProxyPayloadOutsideTLS());
     if (iostate == IOState::Done) {
       d_buffer.resize(d_currentPos);
       ssize_t remaining = isProxyHeaderComplete(d_buffer);
@@ -887,6 +887,30 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand
   return ProxyProtocolResult::Reading;
 }
 
+IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now)
+{
+  DEBUGLOG("doing handshake");
+  auto iostate = d_handler.tryHandshake();
+  if (iostate == IOState::Done) {
+    DEBUGLOG("handshake done");
+    handleHandshakeDone(now);
+
+    if (!isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
+      d_state = State::readingProxyProtocolHeader;
+      d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+      d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+    }
+    else {
+      d_state = State::readingQuerySize;
+    }
+  }
+  else {
+    d_lastIOBlocked = true;
+  }
+
+  return iostate;
+}
+
 void IncomingTCPConnectionState::handleIO()
 {
   // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
@@ -909,34 +933,34 @@ void IncomingTCPConnectionState::handleIO()
     d_lastIOBlocked = false;
 
     try {
-      if (d_state == IncomingTCPConnectionState::State::doingHandshake) {
-        DEBUGLOG("doing handshake");
-        iostate = d_handler.tryHandshake();
-        if (iostate == IOState::Done) {
-          DEBUGLOG("handshake done");
-          handleHandshakeDone(now);
-
-          if (expectProxyProtocolFrom(d_ci.remote)) {
-            d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
-            d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
-            d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
-          }
-          else {
-            d_state = IncomingTCPConnectionState::State::readingQuerySize;
-          }
+      if (d_state == State::starting) {
+        if (isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
+          d_state = State::readingProxyProtocolHeader;
+          d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+          d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
         }
         else {
-          d_lastIOBlocked = true;
+          d_state = State::doingHandshake;
         }
       }
 
-      if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
+      if (d_state == State::doingHandshake) {
+        iostate = handleHandshake(now);
+      }
+
+      if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) {
         auto status = handleProxyProtocolPayload();
         if (status == ProxyProtocolResult::Done) {
-          d_state = IncomingTCPConnectionState::State::readingQuerySize;
-          d_buffer.resize(sizeof(uint16_t));
-          d_currentPos = 0;
-          d_proxyProtocolNeed = 0;
+          if (isProxyPayloadOutsideTLS()) {
+            d_state = State::doingHandshake;
+            iostate = handleHandshake(now);
+          }
+          else {
+            d_state = State::readingQuerySize;
+            d_buffer.resize(sizeof(uint16_t));
+            d_currentPos = 0;
+            d_proxyProtocolNeed = 0;
+          }
         }
         else if (status == ProxyProtocolResult::Error) {
           iostate = IOState::Done;
@@ -946,19 +970,19 @@ void IncomingTCPConnectionState::handleIO()
         }
       }
 
-      if (!d_lastIOBlocked && (d_state == IncomingTCPConnectionState::State::waitingForQuery ||
-                                      d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
+      if (!d_lastIOBlocked && (d_state == State::waitingForQuery ||
+                                      d_state == State::readingQuerySize)) {
         DEBUGLOG("reading query size");
         d_buffer.resize(sizeof(uint16_t));
         iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t));
         if (d_currentPos > 0) {
           /* if we got at least one byte, we can't go around sending responses */
-          d_state = IncomingTCPConnectionState::State::readingQuerySize;
+          d_state = State::readingQuerySize;
         }
 
         if (iostate == IOState::Done) {
           DEBUGLOG("query size received");
-          d_state = IncomingTCPConnectionState::State::readingQuery;
+          d_state = State::readingQuery;
           d_querySizeReadTime = now;
           if (d_queriesCount == 0) {
             d_firstQuerySizeReadTime = now;
@@ -980,14 +1004,14 @@ void IncomingTCPConnectionState::handleIO()
         }
       }
 
-      if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingQuery) {
+      if (!d_lastIOBlocked && d_state == State::readingQuery) {
         DEBUGLOG("reading query");
         iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize);
         if (iostate == IOState::Done) {
           DEBUGLOG("query received");
           d_buffer.resize(d_querySize);
 
-          d_state = IncomingTCPConnectionState::State::idle;
+          d_state = State::idle;
           auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt);
           switch (processingResult) {
           case QueryProcessingResult::TooSmall:
@@ -1005,7 +1029,7 @@ void IncomingTCPConnectionState::handleIO()
 
           /* the state might have been updated in the meantime, we don't want to override it
              in that case */
-          if (active() && d_state != IncomingTCPConnectionState::State::idle) {
+          if (active() && d_state != State::idle) {
             if (d_ioState->isWaitingForRead()) {
               iostate = IOState::NeedRead;
             }
@@ -1022,13 +1046,13 @@ void IncomingTCPConnectionState::handleIO()
         }
       }
 
-      if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      if (!d_lastIOBlocked && d_state == State::sendingResponse) {
         DEBUGLOG("sending response");
         iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
         if (iostate == IOState::Done) {
           DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
           handleResponseSent(d_currentResponse);
-          d_state = IncomingTCPConnectionState::State::idle;
+          d_state = State::idle;
         }
         else {
           d_lastIOBlocked = true;
@@ -1038,8 +1062,8 @@ void IncomingTCPConnectionState::handleIO()
       if (active() &&
           !d_lastIOBlocked &&
           iostate == IOState::Done &&
-          (d_state == IncomingTCPConnectionState::State::idle ||
-           d_state == IncomingTCPConnectionState::State::waitingForQuery))
+          (d_state == State::idle ||
+           d_state == State::waitingForQuery))
       {
         // try sending queued responses
         DEBUGLOG("send responses, if any");
@@ -1054,19 +1078,19 @@ void IncomingTCPConnectionState::handleIO()
             iostate = IOState::NeedRead;
           }
           else {
-            d_state = IncomingTCPConnectionState::State::idle;
+            d_state = State::idle;
             iostate = IOState::Done;
           }
         }
       }
 
-      if (d_state != IncomingTCPConnectionState::State::idle &&
-          d_state != IncomingTCPConnectionState::State::doingHandshake &&
-          d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
-          d_state != IncomingTCPConnectionState::State::waitingForQuery &&
-          d_state != IncomingTCPConnectionState::State::readingQuerySize &&
-          d_state != IncomingTCPConnectionState::State::readingQuery &&
-          d_state != IncomingTCPConnectionState::State::sendingResponse) {
+      if (d_state != State::idle &&
+          d_state != State::doingHandshake &&
+          d_state != State::readingProxyProtocolHeader &&
+          d_state != State::waitingForQuery &&
+          d_state != State::readingQuerySize &&
+          d_state != State::readingQuery &&
+          d_state != State::sendingResponse) {
         vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state));
       }
     }
@@ -1075,18 +1099,18 @@ void IncomingTCPConnectionState::handleIO()
          but it might also be a real IO error or something else.
          Let's just drop the connection
       */
-      if (d_state == IncomingTCPConnectionState::State::idle ||
-          d_state == IncomingTCPConnectionState::State::waitingForQuery) {
+      if (d_state == State::idle ||
+          d_state == State::waitingForQuery) {
         /* no need to increase any counters in that case, the client is simply done with us */
       }
-      else if (d_state == IncomingTCPConnectionState::State::doingHandshake ||
-               d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
-               d_state == IncomingTCPConnectionState::State::waitingForQuery ||
-               d_state == IncomingTCPConnectionState::State::readingQuerySize ||
-               d_state == IncomingTCPConnectionState::State::readingQuery) {
+      else if (d_state == State::doingHandshake ||
+               d_state != State::readingProxyProtocolHeader ||
+               d_state == State::waitingForQuery ||
+               d_state == State::readingQuerySize ||
+               d_state == State::readingQuery) {
         ++d_ci.cs->tcpDiedReadingQuery;
       }
-      else if (d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      else if (d_state == State::sendingResponse) {
         /* unlikely to happen here, the exception should be handled in sendResponse() */
         ++d_ci.cs->tcpDiedSendingResponse;
       }
@@ -1180,7 +1204,7 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnec
   else {
     DEBUGLOG("Going idle");
     /* we still have some queries in flight, let's just stop reading for now */
-    state->d_state = IncomingTCPConnectionState::State::idle;
+    state->d_state = State::idle;
     state->d_ioState->update(IOState::Done, handleIOCallback, state);
   }
 }
index a9ecef0170809c6af5cfc0053e2d84bed317563e..694dab3a6cbeb74328c218ee5a50370809fb07dc 100644 (file)
@@ -529,6 +529,17 @@ struct ClientState
     return tlsFrontend != nullptr || (dohFrontend != nullptr && dohFrontend->isHTTPS());
   }
 
+  const TLSFrontend& getTLSFrontend() const
+  {
+    if (tlsFrontend != nullptr) {
+      return *tlsFrontend;
+    }
+    if (dohFrontend) {
+      return dohFrontend->d_tlsContext;
+    }
+    throw std::runtime_error("Trying to get a TLS frontend from a non-TLS ClientState");
+  }
+
   dnsdist::Protocol getProtocol() const
   {
     if (dnscryptCtx) {
index 7945a2544f304493661092670ff5ceb9af487c97..b22da0c3971572ff25e4e6a4c0864577430c2f53 100644 (file)
@@ -290,6 +290,32 @@ bool IncomingHTTP2Connection::hasPendingWrite() const
   return d_pendingWrite;
 }
 
+IOState IncomingHTTP2Connection::handleHandshake(const struct timeval& now)
+{
+  auto iostate = d_handler.tryHandshake();
+  if (iostate == IOState::Done) {
+    handleHandshakeDone(now);
+    if (d_handler.isTLS()) {
+      if (!checkALPN()) {
+        d_connectionDied = true;
+        stopIO();
+        return iostate;
+      }
+    }
+
+    if (!isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
+      d_state = State::readingProxyProtocolHeader;
+      d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+      d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+    }
+    else {
+      d_state = State::waitingForQuery;
+      handleConnectionReady();
+    }
+  }
+  return iostate;
+}
+
 void IncomingHTTP2Connection::handleIO()
 {
   IOState iostate = IOState::Done;
@@ -306,39 +332,42 @@ void IncomingHTTP2Connection::handleIO()
       return;
     }
 
+    if (d_state == State::starting) {
+      if (isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
+        d_state = State::readingProxyProtocolHeader;
+        d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+        d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+      }
+      else {
+        d_state = State::doingHandshake;
+      }
+    }
+
     if (d_state == State::doingHandshake) {
-      iostate = d_handler.tryHandshake();
-      if (iostate == IOState::Done) {
-        handleHandshakeDone(now);
-        if (d_handler.isTLS()) {
-          if (!checkALPN()) {
-            d_connectionDied = true;
-            stopIO();
+      iostate = handleHandshake(now);
+      if (d_connectionDied) {
+        return;
+      }
+    }
+
+    if (d_state == State::readingProxyProtocolHeader) {
+      auto status = handleProxyProtocolPayload();
+      if (status == ProxyProtocolResult::Done) {
+        if (isProxyPayloadOutsideTLS()) {
+          d_state = State::doingHandshake;
+          iostate = handleHandshake(now);
+          if (d_connectionDied) {
             return;
           }
         }
-
-        if (expectProxyProtocolFrom(d_ci.remote)) {
-          d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
-          d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
-          d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
-        }
         else {
+          d_currentPos = 0;
+          d_proxyProtocolNeed = 0;
+          d_buffer.clear();
           d_state = State::waitingForQuery;
           handleConnectionReady();
         }
       }
-    }
-
-    if (d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
-      auto status = handleProxyProtocolPayload();
-      if (status == ProxyProtocolResult::Done) {
-        d_currentPos = 0;
-        d_proxyProtocolNeed = 0;
-        d_buffer.clear();
-        d_state = State::waitingForQuery;
-        handleConnectionReady();
-      }
       else if (status == ProxyProtocolResult::Error) {
         d_connectionDied = true;
         stopIO();
index 32af8d3087d4fffd940bc62a69485f3969191b0a..a648a5027b3f15e7a3a6eac8082eaaddb3b445fe 100644 (file)
@@ -95,6 +95,7 @@ private:
   bool checkALPN();
   IOState readHTTPData();
   void handleConnectionReady();
+  IOState handleHandshake(const struct timeval& now) override;
   bool hasPendingWrite() const;
   void writeToSocket(bool socketReady);
   boost::optional<struct timeval> getIdleClientReadTTD(struct timeval now) const;
index 4318892659f6233b8dfc172f9d93bba040e24601..f1c49e93e4e2ede74c98f39b945ac559f78ff925 100644 (file)
@@ -138,6 +138,7 @@ public:
 
   virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response);
   void handleResponseSent(TCPResponse& currentResponse);
+  virtual IOState handleHandshake(const struct timeval& now);
   void handleHandshakeDone(const struct timeval& now);
   ProxyProtocolResult handleProxyProtocolPayload();
   void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
@@ -150,6 +151,14 @@ public:
   {
     return d_ioState != nullptr;
   }
+  bool isProxyPayloadOutsideTLS() const
+  {
+    if (!d_ci.cs->hasTLS()) {
+      return false;
+    }
+    return d_ci.cs->getTLSFrontend().d_proxyProtocolOutsideTLS;
+  }
+
   virtual bool forwardViaUDPFirst() const
   {
     return false;
@@ -174,7 +183,7 @@ public:
 
   dnsdist::Protocol getProtocol() const;
 
-  enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
+  enum class State : uint8_t { starting, doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
 
   TCPResponse d_currentResponse;
   std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_ownedConnectionsToBackend;
@@ -199,7 +208,7 @@ public:
   size_t d_currentQueriesCount{0};
   std::thread::id d_creatorThreadID;
   uint16_t d_querySize{0};
-  State d_state{State::doingHandshake};
+  State d_state{State::starting};
   bool d_isXFR{false};
   bool d_proxyProtocolPayloadHasTLV{false};
   bool d_lastIOBlocked{false};
index 5e1d23e737c6dd613f8d701e95a1f64e323d15d8..3cf674ca16562fa5a3aa48d41ba0ce9210f3af01 100644 (file)
@@ -226,6 +226,8 @@ public:
   ComboAddress d_addr;
   std::string d_provider;
   ALPN d_alpn{ALPN::Unset};
+  /* whether the proxy protocol is inside or outside the TLS layer */
+  bool d_proxyProtocolOutsideTLS{false};
 protected:
   std::shared_ptr<TLSCtx> d_ctx{nullptr};
 };
@@ -365,13 +367,13 @@ public:
      return Done when toRead bytes have been read, needRead or needWrite if the IO operation
      would block.
   */
-  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false)
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false)
   {
     if (buffer.size() < toRead || pos >= toRead) {
       throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
     }
 
-    if (d_conn) {
+    if (!bypassFilters && d_conn) {
       return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
     }