]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Limit the number of concurrent web connections
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 21 Jan 2021 13:43:37 +0000 (14:43 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 29 Jan 2021 09:57:51 +0000 (10:57 +0100)
pdns/dnsdist-console.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-web.cc
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/connection-management.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-web.hh
pdns/dnsdistdist/docs/reference/config.rst
regression-tests.dnsdist/test_API.py

index 3d8b02c1f3a063a36a69d53b04b376d368b8e335..e4de1815947c54f8c9a367257963b8ea1953b4c6 100644 (file)
@@ -36,6 +36,7 @@
 
 #include "ext/json11/json11.hpp"
 
+#include "connection-management.hh"
 #include "dolog.hh"
 #include "dnsdist.hh"
 #include "dnsdist-console.hh"
@@ -49,47 +50,14 @@ bool g_logConsoleConnections{true};
 bool g_consoleEnabled{false};
 uint32_t g_consoleOutputMsgMaxSize{10000000};
 
-class ConcurrentConnectionManager
-{
-public:
-  static void setMaxConcurrentConnections(size_t max)
-  {
-    std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
-    s_maxConcurrentConsoleConnections = max;
-  }
-
-  static bool registerConnection()
-  {
-    std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
-    if (s_maxConcurrentConsoleConnections == 0 || s_currentConnectionsCount < s_maxConcurrentConsoleConnections) {
-      ++s_currentConnectionsCount;
-      return true;
-    }
-    return false;
-  }
-
-  static void releaseConnection()
-  {
-    std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
-    --s_currentConnectionsCount;
-  }
-
-private:
-  static std::mutex s_concurrentConsoleConnectionsLock;
-  static size_t s_maxConcurrentConsoleConnections;
-  static size_t s_currentConnectionsCount;
-};
-
-size_t ConcurrentConnectionManager::s_maxConcurrentConsoleConnections{100};
-size_t ConcurrentConnectionManager::s_currentConnectionsCount{0};
-std::mutex ConcurrentConnectionManager::s_concurrentConsoleConnectionsLock;
+static ConcurrentConnectionManager s_connManager(100);
 
 class ConsoleConnection
 {
 public:
   ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd)
   {
-    if (!ConcurrentConnectionManager::registerConnection()) {
+    if (!s_connManager.registerConnection()) {
       close(fd);
       throw std::runtime_error("Too many concurrent console connections");
     }
@@ -106,7 +74,7 @@ public:
   {
     if (d_fd != -1) {
       close(d_fd);
-      ConcurrentConnectionManager::releaseConnection();
+      s_connManager.releaseConnection();
     }
   }
 
@@ -127,7 +95,7 @@ private:
 
 void setConsoleMaximumConcurrentConnections(size_t max)
 {
-  ConcurrentConnectionManager::setMaxConcurrentConnections(max);
+  s_connManager.setMaxConcurrentConnections(max);
 }
 
 // MUST BE CALLED UNDER A LOCK - right now the LuaLock
index a686c886e3453f5075275ca667d6419cb1fd8a4e..ee7ae516bfffba7cb989fa09aa96e61d2cd777ad 100644 (file)
@@ -897,7 +897,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
     });
 
-  typedef std::unordered_map<std::string, boost::variant<bool, std::string, std::map<std::string, std::string>> > webserveropts_t;
+  typedef std::unordered_map<std::string, boost::variant<bool, size_t, std::string, std::map<std::string, std::string>> > webserveropts_t;
 
   luaCtx.writeFunction("setWebserverConfig", [](boost::optional<webserveropts_t> vars) {
       setLuaSideEffect();
@@ -933,6 +933,10 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       if (vars->count("statsRequireAuthentication")) {
         setWebserverStatsRequireAuthentication(boost::get<bool>(vars->at("statsRequireAuthentication")));
       }
+
+      if (vars->count("maxConcurrentConnections")) {
+        setWebserverMaxConcurrentConnections(boost::get<size_t>(vars->at("maxConcurrentConnections")));
+      }
     });
 
   luaCtx.writeFunction("controlSocket", [client,configCheck](const std::string& str) {
index 3f1b90c40e7ee6cecbd92bfb56d7d391e5426428..303331eecdad920db5c4fe98127fad9abc2c6399 100644 (file)
@@ -30,6 +30,7 @@
 #include <yahttp/yahttp.hpp>
 
 #include "base64.hh"
+#include "connection-management.hh"
 #include "dnsdist.hh"
 #include "dnsdist-dynblocks.hh"
 #include "dnsdist-healthchecks.hh"
@@ -61,6 +62,46 @@ WebserverConfig g_webserverConfig;
 std::string g_apiConfigDirectory;
 static const MetricDefinitionStorage s_metricDefinitions;
 
+static ConcurrentConnectionManager s_connManager(100);
+
+class WebClientConnection
+{
+public:
+  WebClientConnection(const ComboAddress& client, int fd): d_client(client), d_socket(fd)
+  {
+    if (!s_connManager.registerConnection()) {
+      throw std::runtime_error("Too many concurrent web client connections");
+    }
+  }
+  WebClientConnection(WebClientConnection&& rhs): d_client(rhs.d_client), d_socket(std::move(rhs.d_socket))
+  {
+  }
+
+  WebClientConnection(const WebClientConnection&) = delete;
+  WebClientConnection& operator=(const WebClientConnection&) = delete;
+
+  ~WebClientConnection()
+  {
+    if (d_socket.getHandle() != -1) {
+      s_connManager.releaseConnection();
+    }
+  }
+
+  const Socket& getSocket() const
+  {
+    return d_socket;
+  }
+
+  const ComboAddress& getClient() const
+  {
+    return d_client;
+  }
+
+private:
+  ComboAddress d_client;
+  Socket d_socket;
+};
+
 const std::map<std::string, MetricDefinition> MetricDefinitionStorage::metrics{
   { "responses",              MetricDefinition(PrometheusMetricType::counter, "Number of responses received from backends") },
   { "servfail-responses",     MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers received from backends") },
@@ -1246,14 +1287,11 @@ void registerBuiltInWebHandlers()
   }
 }
 
-static void connectionThread(int sockFD, ComboAddress remote)
+static void connectionThread(WebClientConnection&& conn)
 {
   setThreadName("dnsdist/webConn");
 
-  vinfolog("Webserver handling connection from %s", remote.toStringWithPort());
-
-  Socket sock(sockFD);
-  sockFD = -1;
+  vinfolog("Webserver handling connection from %s", conn.getClient().toStringWithPort());
 
   try {
     YaHTTP::AsyncRequestLoader yarl;
@@ -1264,7 +1302,7 @@ static void connectionThread(int sockFD, ComboAddress remote)
     while (!finished) {
       int bytes;
       char buf[1024];
-      bytes = read(sock.getHandle(), buf, sizeof(buf));
+      bytes = read(conn.getSocket().getHandle(), buf, sizeof(buf));
       if (bytes > 0) {
         string data = string(buf, bytes);
         finished = yarl.feed(data);
@@ -1300,7 +1338,7 @@ static void connectionThread(int sockFD, ComboAddress remote)
     else if (!handleAuthorization(req)) {
       YaHTTP::strstr_map_t::iterator header = req.headers.find("authorization");
       if (header != req.headers.end()) {
-        errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, remote.toStringWithPort());
+        errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, conn.getClient().toStringWithPort());
       }
       resp.status = 401;
       resp.body = "<h1>Unauthorized</h1>";
@@ -1322,16 +1360,16 @@ static void connectionThread(int sockFD, ComboAddress remote)
     std::ostringstream ofs;
     ofs << resp;
     string done = ofs.str();
-    writen2(sock.getHandle(), done.c_str(), done.size());
+    writen2(conn.getSocket().getHandle(), done.c_str(), done.size());
   }
   catch (const YaHTTP::ParseError& e) {
-    vinfolog("Webserver thread died with parse error exception while processing a request from %s: %s", remote.toStringWithPort(), e.what());
+    vinfolog("Webserver thread died with parse error exception while processing a request from %s: %s", conn.getClient().toStringWithPort(), e.what());
   }
   catch (const std::exception& e) {
-    errlog("Webserver thread died with exception while processing a request from %s: %s", remote.toStringWithPort(), e.what());
+    errlog("Webserver thread died with exception while processing a request from %s: %s", conn.getClient().toStringWithPort(), e.what());
   }
   catch (...) {
-    errlog("Webserver thread died with exception while processing a request from %s", remote.toStringWithPort());
+    errlog("Webserver thread died with exception while processing a request from %s", conn.getClient().toStringWithPort());
   }
 }
 
@@ -1378,6 +1416,11 @@ void setWebserverStatsRequireAuthentication(bool require)
   g_webserverConfig.statsRequireAuthentication = require;
 }
 
+void setWebserverMaxConcurrentConnections(size_t max)
+{
+  s_connManager.setMaxConcurrentConnections(max);
+}
+
 void dnsdistWebserverThread(int sock, const ComboAddress& local)
 {
   setThreadName("dnsdist/webserv");
@@ -1394,13 +1437,17 @@ void dnsdistWebserverThread(int sock, const ComboAddress& local)
     try {
       ComboAddress remote(local);
       int fd = SAccept(sock, remote);
+
       if (!isClientAllowedByACL(remote)) {
         vinfolog("Connection to webserver from client %s is not allowed, closing", remote.toStringWithPort());
         close(fd);
         continue;
       }
+
+      WebClientConnection conn(remote, fd);
       vinfolog("Got a connection to the webserver from %s", remote.toStringWithPort());
-      std::thread t(connectionThread, fd, remote);
+
+      std::thread t(connectionThread, std::move(conn));
       t.detach();
     }
     catch (const std::exception& e) {
index ad1576fc6f7b176580344422a0c33e32c3af2610..da64c8c9b9fedff13c09e581e7e7a2055458e8f8 100644 (file)
@@ -126,6 +126,7 @@ dnsdist_SOURCES = \
        cachecleaner.hh \
        capabilities.cc capabilities.hh \
        circular_buffer.hh \
+       connection-management.hh \
        dns.cc dns.hh \
        dnscrypt.cc dnscrypt.hh \
        dnsdist-backend.cc \
diff --git a/pdns/dnsdistdist/connection-management.hh b/pdns/dnsdistdist/connection-management.hh
new file mode 100644 (file)
index 0000000..7dd4fd6
--- /dev/null
@@ -0,0 +1,59 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <mutex>
+
+class ConcurrentConnectionManager
+{
+public:
+  ConcurrentConnectionManager(size_t max): d_maxConcurrentConnections(max)
+  {
+  }
+
+  void setMaxConcurrentConnections(size_t max)
+  {
+    std::lock_guard<decltype(d_concurrentConnectionsLock)> lock(d_concurrentConnectionsLock);
+    d_maxConcurrentConnections = max;
+  }
+
+  bool registerConnection()
+  {
+    std::lock_guard<decltype(d_concurrentConnectionsLock)> lock(d_concurrentConnectionsLock);
+    if (d_maxConcurrentConnections == 0 || d_currentConnectionsCount < d_maxConcurrentConnections) {
+      ++d_currentConnectionsCount;
+      return true;
+    }
+    return false;
+  }
+
+  void releaseConnection()
+  {
+    std::lock_guard<decltype(d_concurrentConnectionsLock)> lock(d_concurrentConnectionsLock);
+    --d_currentConnectionsCount;
+  }
+
+private:
+  std::mutex d_concurrentConnectionsLock;
+  size_t d_maxConcurrentConnections{0};
+  size_t d_currentConnectionsCount{0};
+};
index d0d70e6821c335619d1c2cf0602a1165012b65b9..ea869274c475de44d612bc3ab620c34f0b0479a5 100644 (file)
@@ -5,6 +5,7 @@ void setWebserverPassword(const std::string& password);
 void setWebserverACL(const std::string& acl);
 void setWebserverCustomHeaders(const boost::optional<std::map<std::string, std::string> > customHeaders);
 void setWebserverStatsRequireAuthentication(bool);
+void setWebserverMaxConcurrentConnections(size_t);
 
 void dnsdistWebserverThread(int sock, const ComboAddress& local);
 
index a03ccdc0a9029ffe1491d6d686aab3456942d3bf..c19c17fd8e2cc8757104bb166d3faa5bfa2cf629 100644 (file)
@@ -315,7 +315,7 @@ Webserver configuration
     ``acl`` optional parameter added.
 
   .. versionchanged:: 1.6.0
-    ``statsRequireAuthentication`` optional parameter added.
+    ``statsRequireAuthentication``, ``maxConcurrentConnections`` optional parameters added.
 
   Setup webserver configuration. See :func:`webserver`.
 
@@ -328,6 +328,7 @@ Webserver configuration
   * ``custom_headers={[str]=str,...}``: map of string - Allows setting custom headers and removing the defaults.
   * ``acl=newACL``: string - List of IP addresses, as a string, that are allowed to open a connection to the web server. Defaults to "127.0.0.1, ::1".
   * ``statsRequireAuthentication``: bool - Whether access to the statistics (/metrics and /jsonstat endpoints) require a valid password or API key. Defaults to true.
+  * ``maxConcurrentConnections``: int - The maximum number of concurrent web connections, or 0 which means an unlimited number. Defaults to 100.
 
 .. function:: registerWebHandler(path, handler)
 
index b331ddb28a867bd305a7e3b0c24f62eaa30d41fb..461dc554f6a0aeb58552b61842273b97c74d7386 100644 (file)
@@ -4,6 +4,8 @@ import os.path
 import base64
 import json
 import requests
+import socket
+import time
 from dnsdisttests import DNSDistTest
 
 class TestAPIBasics(DNSDistTest):
@@ -674,3 +676,44 @@ class TestCustomLuaEndpoint(DNSDistTest):
         self.assertEquals(r.status_code, 200)
         self.assertEquals(r.content, b'It works!')
         self.assertEquals(r.headers.get('foo'), "Bar")
+
+class TestWebConcurrentConnectionsL(DNSDistTest):
+
+    _webTimeout = 2.0
+    _webServerPort = 8083
+    _webServerBasicAuthPassword = 'secret'
+    _webServerAPIKey = 'apisecret'
+    _maxConns = 2
+
+    _config_params = ['_testServerPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey', '_maxConns']
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s", maxConcurrentConnections=%d})
+    """
+
+    def testConcurrentConnections(self):
+        """
+        Web: Concurrent connections
+        """
+
+        conns = []
+        # open the maximum number of connections
+        for _ in range(self._maxConns):
+            conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            conn.connect(("127.0.0.1", self._webServerPort))
+            conns.append(conn)
+
+        # we now hold all the slots, let's try to establish a new connection
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + "/"
+        self.assertRaises(requests.exceptions.ConnectionError, requests.get, url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout)
+
+        # free one slot
+        conns[0].close()
+        conns[0] = None
+        time.sleep(1)
+
+        # this should work
+        r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout)
+        self.assertTrue(r)
+        self.assertEquals(r.status_code, 200)