]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Make a separate DoQ connections map per bind
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 27 Sep 2023 23:28:49 +0000 (01:28 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Oct 2023 11:38:07 +0000 (13:38 +0200)
pdns/dnsdistdist/doq.cc

index 1ba7a8a638ac94f5d9c85920822697e241757f54..a90d90e5e1df83e57b302df035a96ac5a868f693 100644 (file)
@@ -85,7 +85,10 @@ struct DOQServerConfig
   DOQServerConfig& operator=(DOQServerConfig&&) = default;
   ~DOQServerConfig() = default;
 
+  using ConnectionsMap = std::map<PacketBuffer, Connection>;
+
   LocalHolders holders;
+  ConnectionsMap d_connections;
   QuicheConfig config;
   ClientState* clientState{nullptr};
   std::shared_ptr<DOQFrontend> df{nullptr};
@@ -114,8 +117,6 @@ DOQFrontend::~DOQFrontend()
 static constexpr size_t MAX_DATAGRAM_SIZE = 1200;
 static constexpr size_t LOCAL_CONN_ID_LEN = 16;
 
-static std::map<PacketBuffer, Connection> s_connections;
-
 class DOQTCPCrossQuerySender final : public TCPQuerySender
 {
 public:
@@ -489,13 +490,13 @@ static void handleVersionNegociation(Socket& sock, const PacketBuffer& clientCon
   sock.sendTo(reinterpret_cast<const char*>(out.data()), written, peer);
 }
 
-static std::optional<std::reference_wrapper<Connection>> getConnection(const PacketBuffer& id)
+static std::optional<std::reference_wrapper<Connection>> getConnection(DOQServerConfig::ConnectionsMap& connMap, const PacketBuffer& connID)
 {
-  auto it = s_connections.find(id);
-  if (it == s_connections.end()) {
+  auto iter = connMap.find(connID);
+  if (iter == connMap.end()) {
     return std::nullopt;
   }
-  return it->second;
+  return iter->second;
 }
 
 static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description)
@@ -514,7 +515,7 @@ static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description)
   }
 }
 
-static std::optional<std::reference_wrapper<Connection>> createConnection(const DOQServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const PacketBuffer& token, const ComboAddress& local, const ComboAddress& peer)
+static std::optional<std::reference_wrapper<Connection>> createConnection(DOQServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer)
 {
   auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(),
                                                    originalDestinationID.data(), originalDestinationID.size(),
@@ -530,7 +531,7 @@ static std::optional<std::reference_wrapper<Connection>> createConnection(const
   }
 
   auto conn = Connection(peer, std::move(quicheConn));
-  auto pair = s_connections.emplace(serverSideID, std::move(conn));
+  auto pair = config.d_connections.emplace(serverSideID, std::move(conn));
   return pair.first->second;
 }
 
@@ -588,7 +589,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& unit)
 {
   const auto handleImmediateResponse = [](DOQUnitUniquePtr&& du, const char* reason) {
     DEBUGLOG("handleImmediateResponse() reason=" << reason);
-    auto conn = getConnection(du->serverConnID);
+    auto conn = getConnection(du->dsc->df->d_server_config->d_connections, du->serverConnID);
     handleResponse(*du->dsc->df, *conn, du->streamID, du->response);
     du->ids.doqu.reset();
   };
@@ -757,10 +758,11 @@ static void flushResponses(pdns::channel::Receiver<DOQUnit>& receiver)
         return;
       }
 
-      auto du = std::move(*tmp);
-      auto conn = getConnection(du->serverConnID);
-
-      handleResponse(*du->dsc->df, *conn, du->streamID, du->response);
+      auto unit = std::move(*tmp);
+      auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
+      if (conn) {
+        handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->response);
+      }
     }
     catch (const std::exception& e) {
       errlog("Error while processing response received over DoQ: %s", e.what());
@@ -821,7 +823,7 @@ void doqThread(ClientState* clientState)
         PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len);
         // source connection ID, will have to be sent as destination connection ID
         PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len);
-        auto conn = getConnection(serverConnID);
+        auto conn = getConnection(frontend->d_server_config->d_connections, serverConnID);
 
         if (!conn) {
           DEBUGLOG("Connection not found");
@@ -848,7 +850,7 @@ void doqThread(ClientState* clientState)
           }
 
           DEBUGLOG("Creating a new connection");
-          conn = createConnection(*frontend->d_server_config, serverConnID, *originalDestinationID, tokenBuf, clientState->local, client);
+          conn = createConnection(*frontend->d_server_config, serverConnID, *originalDestinationID, clientState->local, client);
           if (!conn) {
             continue;
           }
@@ -908,7 +910,7 @@ void doqThread(ClientState* clientState)
         flushResponses(frontend->d_server_config->d_responseReceiver);
       }
 
-      for (auto conn = s_connections.begin(); conn != s_connections.end();) {
+      for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
         quiche_conn_on_timeout(conn->second.d_conn.get());
 
         flushEgress(sock, conn->second);
@@ -923,7 +925,7 @@ void doqThread(ClientState* clientState)
 
           DEBUGLOG("Connection closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
 #endif
-          conn = s_connections.erase(conn);
+          conn = frontend->d_server_config->d_connections.erase(conn);
         }
         else {
           ++conn;