]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Apply the max number of concurrent conns per client to DoH
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 27 Jan 2023 10:13:17 +0000 (11:13 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 27 Jan 2023 14:22:00 +0000 (15:22 +0100)
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-concurrent-connections.hh [new file with mode: 0644]
pdns/dnsdistdist/doh.cc

index 2df956334d1850f9bd95bdec694ed3b1e423d358..e7d8c2dfed15815f6943513a1eda6c3a63892d55 100644 (file)
@@ -37,6 +37,7 @@
 
 #include "dnsdist.hh"
 #include "dnsdist-carbon.hh"
+#include "dnsdist-concurrent-connections.hh"
 #include "dnsdist-console.hh"
 #include "dnsdist-dynblocks.hh"
 #include "dnsdist-discovery.hh"
@@ -1442,7 +1443,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
   luaCtx.writeFunction("setMaxTCPConnectionsPerClient", [](uint64_t max) {
     if (!g_configurationDone) {
-      g_maxTCPConnectionsPerClient = max;
+      dnsdist::IncomingConcurrentTCPConnectionsManager::setMaxTCPConnectionsPerClient(max);
     }
     else {
       g_outputBuffer = "The maximum number of TCP connection per client cannot be altered at runtime!\n";
index f701aaed26d4c6c1bcf1ceaba4778e8606c7cea5..7a781f15985812d03631ef37b9db488367f62bc6 100644 (file)
@@ -25,6 +25,7 @@
 #include <queue>
 
 #include "dnsdist.hh"
+#include "dnsdist-concurrent-connections.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
    Let's start naively.
 */
 
-static LockGuarded<std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan>> s_tcpClientsCount;
-
 size_t g_maxTCPQueriesPerConn{0};
 size_t g_maxTCPConnectionDuration{0};
-size_t g_maxTCPConnectionsPerClient{0};
+
 #ifdef __linux__
 // On Linux this gives us 128k pending queries (default is 8192 queries),
 // which should be enough to deal with huge spikes
@@ -76,20 +75,12 @@ int g_tcpRecvTimeout{2};
 int g_tcpSendTimeout{2};
 std::atomic<uint64_t> g_tcpStatesDumpRequested{0};
 
-static void decrementTCPClientCount(const ComboAddress& client)
-{
-  if (g_maxTCPConnectionsPerClient) {
-    auto tcpClientsCount = s_tcpClientsCount.lock();
-    tcpClientsCount->at(client)--;
-    if (tcpClientsCount->at(client) == 0) {
-      tcpClientsCount->erase(client);
-    }
-  }
-}
+LockGuarded<std::map<ComboAddress, size_t, ComboAddress::addressOnlyLessThan>> dnsdist::IncomingConcurrentTCPConnectionsManager::s_tcpClientsConcurrentConnectionsCount;
+size_t dnsdist::IncomingConcurrentTCPConnectionsManager::s_maxTCPConnectionsPerClient = 0;
 
 IncomingTCPConnectionState::~IncomingTCPConnectionState()
 {
-  decrementTCPClientCount(d_ci.remote);
+  dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote);
 
   if (d_ci.cs != nullptr) {
     struct timeval now;
@@ -1462,16 +1453,11 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
       return;
     }
 
-    if (g_maxTCPConnectionsPerClient) {
-      auto tcpClientsCount = s_tcpClientsCount.lock();
-
-      if ((*tcpClientsCount)[remote] >= g_maxTCPConnectionsPerClient) {
-        vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
-        return;
-      }
-      (*tcpClientsCount)[remote]++;
-      tcpClientCountIncremented = true;
+    if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
+      vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
+      return;
     }
+    tcpClientCountIncremented = true;
 
     vinfolog("Got TCP connection from %s", remote.toStringWithPort());
 
@@ -1479,7 +1465,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
     if (threadData == nullptr) {
       if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(ci)))) {
         if (tcpClientCountIncremented) {
-          decrementTCPClientCount(remote);
+          dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
         }
       }
     }
@@ -1493,7 +1479,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
   catch (const std::exception& e) {
     errlog("While reading a TCP question: %s", e.what());
     if (tcpClientCountIncremented) {
-      decrementTCPClientCount(remote);
+      dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
     }
   }
   catch (...){}
index 0741ca266ea901d5acd710ef10a84d36f83aa0d4..6b98fb29d5d19e43830fcc2a6c8943a917141f16 100644 (file)
@@ -1159,7 +1159,6 @@ extern boost::optional<uint64_t> g_maxTCPClientThreads;
 extern uint64_t g_maxTCPQueuedConnections;
 extern size_t g_maxTCPQueriesPerConn;
 extern size_t g_maxTCPConnectionDuration;
-extern size_t g_maxTCPConnectionsPerClient;
 extern size_t g_tcpInternalPipeBufferSize;
 extern pdns::stat16_t g_cacheCleaningDelay;
 extern pdns::stat16_t g_cacheCleaningPercentage;
index 7310e4050de0ad7023527064f0cf6930d7b68e46..c374d555f9cebb05c0549749c1abc6a7139231d7 100644 (file)
@@ -138,6 +138,7 @@ dnsdist_SOURCES = \
        dnsdist-backend.cc \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-carbon.cc dnsdist-carbon.hh \
+       dnsdist-concurrent-connections.hh \
        dnsdist-console.cc dnsdist-console.hh \
        dnsdist-discovery.cc dnsdist-discovery.hh \
        dnsdist-dnscrypt.cc \
@@ -247,6 +248,7 @@ testrunner_SOURCES = \
        dnsdist-async.cc dnsdist-async.hh \
        dnsdist-backend.cc \
        dnsdist-cache.cc dnsdist-cache.hh \
+       dnsdist-concurrent-connections.hh \
        dnsdist-dnsparser.cc dnsdist-dnsparser.hh \
        dnsdist-downstream-connection.hh \
        dnsdist-dynblocks.cc dnsdist-dynblocks.hh \
diff --git a/pdns/dnsdistdist/dnsdist-concurrent-connections.hh b/pdns/dnsdistdist/dnsdist-concurrent-connections.hh
new file mode 100644 (file)
index 0000000..3cf55d5
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <map>
+#include "iputils.hh"
+#include "lock.hh"
+
+namespace dnsdist
+{
+class IncomingConcurrentTCPConnectionsManager
+{
+public:
+  static bool accountNewTCPConnection(const ComboAddress& from)
+  {
+    if (s_maxTCPConnectionsPerClient == 0) {
+      return true;
+    }
+    auto db = s_tcpClientsConcurrentConnectionsCount.lock();
+    auto& count = (*db)[from];
+    if (count >= s_maxTCPConnectionsPerClient) {
+      return false;
+    }
+    ++count;
+    return true;
+  }
+
+  static void accountClosedTCPConnection(const ComboAddress& from)
+  {
+    if (s_maxTCPConnectionsPerClient == 0) {
+      return;
+    }
+    auto db = s_tcpClientsConcurrentConnectionsCount.lock();
+    auto& count = db->at(from);
+    count--;
+    if (count == 0) {
+      db->erase(from);
+    }
+  }
+
+  static void setMaxTCPConnectionsPerClient(size_t max)
+  {
+    s_maxTCPConnectionsPerClient = max;
+  }
+
+private:
+  static LockGuarded<std::map<ComboAddress, size_t, ComboAddress::addressOnlyLessThan>> s_tcpClientsConcurrentConnectionsCount;
+  static size_t s_maxTCPConnectionsPerClient;
+};
+
+}
index 58a6286978f1ee6a9212ff365ede998378477fd6..4ee3625c052b554242da8b0eda5ec184943ed7fb 100644 (file)
@@ -24,6 +24,7 @@
 #include "misc.hh"
 #include "dns.hh"
 #include "dolog.hh"
+#include "dnsdist-concurrent-connections.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rules.hh"
@@ -304,6 +305,7 @@ static void on_socketclose(void *data)
     }
 
     t_conns.erase(conn->d_desc);
+    dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(conn->d_remote);
   }
 }
 
@@ -1007,13 +1009,6 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req)
         ++dsc->cs->tlsResumptions;
       }
 
-      if (h2o_socket_getpeername(sock, reinterpret_cast<struct sockaddr*>(&conn.d_remote)) == 0) {
-        /* getpeername failed, likely because the connection has already been closed,
-           but anyway that means we can't get the remote address, which could allow an ACL bypass */
-        h2o_send_error_500(req, getReasonFromStatusCode(500).c_str(), "Internal Server Error - Unable to get remote address", 0);
-        return 0;
-      }
-
       h2o_socket_getsockname(sock, reinterpret_cast<struct sockaddr*>(&conn.d_local));
     }
 
@@ -1401,6 +1396,19 @@ static void on_accept(h2o_socket_t *listener, const char *err)
     return;
   }
 
+  ComboAddress remote;
+  if (h2o_socket_getpeername(sock, reinterpret_cast<struct sockaddr*>(&remote)) == 0) {
+    vinfolog("Dropping DoH connection because we could not retrieve the remote host");
+    h2o_socket_close(sock);
+    return;
+  }
+
+  if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
+    vinfolog("Dropping DoH connection from %s because we have too many from this client already", remote.toStringWithPort());
+    h2o_socket_close(sock);
+    return;
+  }
+
   auto concurrentConnections = ++dsc->cs->tcpCurrentConnections;
   if (dsc->cs->d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > dsc->cs->d_tcpConcurrentConnectionsLimit) {
     --dsc->cs->tcpCurrentConnections;
@@ -1418,6 +1426,7 @@ static void on_accept(h2o_socket_t *listener, const char *err)
   conn.d_nbQueries = 0;
   conn.d_acceptCtx = std::atomic_load_explicit(&dsc->accept_ctx, std::memory_order_acquire);
   conn.d_desc = descriptor;
+  conn.d_remote = remote;
 
   sock->on_close.cb = on_socketclose;
   sock->on_close.data = &conn;