]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdistdist/doh3.cc
dnsdist: Use the correct source IP for outgoing QUIC datagrams
[thirdparty/pdns.git] / pdns / dnsdistdist / doh3.cc
index 26b3cf5686ae25faef783713f82105fccbf96169..7ff4748f1551521e65daa445cbec8d17d0b9932a 100644 (file)
@@ -54,8 +54,8 @@ using h3_headers_t = std::map<std::string, std::string>;
 class H3Connection
 {
 public:
-  H3Connection(const ComboAddress& peer, QuicheConfig config, QuicheConnection&& conn) :
-    d_peer(peer), d_conn(std::move(conn)), d_config(std::move(config))
+  H3Connection(const ComboAddress& peer, const ComboAddress& localAddr, QuicheConfig config, QuicheConnection&& conn) :
+    d_peer(peer), d_localAddr(localAddr), d_conn(std::move(conn)), d_config(std::move(config))
   {
   }
   H3Connection(const H3Connection&) = delete;
@@ -65,6 +65,7 @@ public:
   ~H3Connection() = default;
 
   ComboAddress d_peer;
+  ComboAddress d_localAddr;
   QuicheConnection d_conn;
   QuicheConfig d_config;
   QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
@@ -421,14 +422,14 @@ static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description)
   }
 }
 
-static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer)
+static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& localAddr, const ComboAddress& peer)
 {
   auto quicheConfig = std::atomic_load_explicit(&config.config, std::memory_order_acquire);
   auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(),
                                                    originalDestinationID.data(), originalDestinationID.size(),
                                                    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-                                                   reinterpret_cast<const struct sockaddr*>(&local),
-                                                   local.getSocklen(),
+                                                   reinterpret_cast<const struct sockaddr*>(&localAddr),
+                                                   localAddr.getSocklen(),
                                                    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
                                                    reinterpret_cast<const struct sockaddr*>(&peer),
                                                    peer.getSocklen(),
@@ -439,7 +440,7 @@ static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3
     quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_quicheParams.d_keyLogFile.c_str());
   }
 
-  auto conn = H3Connection(peer, std::move(quicheConfig), std::move(quicheConn));
+  auto conn = H3Connection(peer, localAddr, std::move(quicheConfig), std::move(quicheConn));
   auto pair = config.d_connections.emplace(serverSideID, std::move(conn));
   return pair.first->second;
 }
@@ -743,7 +744,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
       return;
     }
     DEBUGLOG("Dispatching GET query");
-    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID);
+    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID);
     conn.d_streamBuffers.erase(streamID);
     conn.d_headersBuffers.erase(streamID);
     return;
@@ -808,7 +809,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
   }
 
   DEBUGLOG("Dispatching POST query");
-  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
+  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
   conn.d_headersBuffers.erase(streamID);
   conn.d_streamBuffers.erase(streamID);
 }
@@ -856,10 +857,21 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
   PacketBuffer tokenBuf;
   while (true) {
     ComboAddress client;
+    ComboAddress localAddr;
+    client.sin4.sin_family = clientState.local.sin4.sin_family;
+    localAddr.sin4.sin_family = clientState.local.sin4.sin_family;
     buffer.resize(4096);
-    if (!sock.recvFromAsync(buffer, client) || buffer.empty()) {
+    if (!dnsdist::doq::recvAsync(sock, buffer, client, localAddr)) {
       return;
     }
+    if (localAddr.sin4.sin_family == 0) {
+      localAddr = clientState.local;
+    }
+    else {
+      /* we don't get the port, only the address */
+      localAddr.sin4.sin_port = clientState.local.sin4.sin_port;
+    }
+
     DEBUGLOG("Received DoH3 datagram of size " << buffer.size() << " from " << client.toStringWithPort());
 
     uint32_t version{0};
@@ -896,14 +908,14 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
       if (!quiche_version_is_supported(version)) {
         DEBUGLOG("Unsupported version");
         ++frontend.d_doh3UnsupportedVersionErrors;
-        handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer);
+        handleVersionNegociation(sock, clientConnID, serverConnID, client, localAddr, buffer);
         continue;
       }
 
       if (token_len == 0) {
         /* stateless retry */
         DEBUGLOG("No token received");
-        handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer);
+        handleStatelessRetry(sock, clientConnID, serverConnID, client, localAddr, version, buffer);
         continue;
       }
 
@@ -916,7 +928,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
       }
 
       DEBUGLOG("Creating a new connection");
-      conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, clientState.local, client);
+      conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, localAddr, client);
       if (!conn) {
         continue;
       }
@@ -927,8 +939,8 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
       reinterpret_cast<struct sockaddr*>(&client),
       client.getSocklen(),
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-      reinterpret_cast<struct sockaddr*>(&clientState.local),
-      clientState.local.getSocklen(),
+      reinterpret_cast<struct sockaddr*>(&localAddr),
+      localAddr.getSocklen(),
     };
 
     auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info);
@@ -950,7 +962,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
 
       processH3Events(clientState, frontend, conn->get(), client, serverConnID, buffer);
 
-      flushEgress(sock, conn->get().d_conn, client, buffer);
+      flushEgress(sock, conn->get().d_conn, client, localAddr, buffer);
     }
     else {
       DEBUGLOG("Connection not established");
@@ -995,7 +1007,7 @@ void doh3Thread(ClientState* clientState)
         for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
           quiche_conn_on_timeout(conn->second.d_conn.get());
 
-          flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer);
+          flushEgress(sock, conn->second.d_conn, conn->second.d_peer, conn->second.d_localAddr, buffer);
 
           if (quiche_conn_is_closed(conn->second.d_conn.get())) {
 #ifdef DEBUGLOG_ENABLED