From: Remi Gacogne Date: Thu, 21 Jan 2021 13:43:37 +0000 (+0100) Subject: dnsdist: Limit the number of concurrent web connections X-Git-Tag: dnsdist-1.6.0-alpha1~4^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5e40d2a5b3210144ba28a846e574bc94d6589296;p=thirdparty%2Fpdns.git dnsdist: Limit the number of concurrent web connections --- diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index 3d8b02c1f3..e4de181594 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -36,6 +36,7 @@ #include "ext/json11/json11.hpp" +#include "connection-management.hh" #include "dolog.hh" #include "dnsdist.hh" #include "dnsdist-console.hh" @@ -49,47 +50,14 @@ bool g_logConsoleConnections{true}; bool g_consoleEnabled{false}; uint32_t g_consoleOutputMsgMaxSize{10000000}; -class ConcurrentConnectionManager -{ -public: - static void setMaxConcurrentConnections(size_t max) - { - std::lock_guard lock(s_concurrentConsoleConnectionsLock); - s_maxConcurrentConsoleConnections = max; - } - - static bool registerConnection() - { - std::lock_guard lock(s_concurrentConsoleConnectionsLock); - if (s_maxConcurrentConsoleConnections == 0 || s_currentConnectionsCount < s_maxConcurrentConsoleConnections) { - ++s_currentConnectionsCount; - return true; - } - return false; - } - - static void releaseConnection() - { - std::lock_guard lock(s_concurrentConsoleConnectionsLock); - --s_currentConnectionsCount; - } - -private: - static std::mutex s_concurrentConsoleConnectionsLock; - static size_t s_maxConcurrentConsoleConnections; - static size_t s_currentConnectionsCount; -}; - -size_t ConcurrentConnectionManager::s_maxConcurrentConsoleConnections{100}; -size_t ConcurrentConnectionManager::s_currentConnectionsCount{0}; -std::mutex ConcurrentConnectionManager::s_concurrentConsoleConnectionsLock; +static ConcurrentConnectionManager s_connManager(100); class ConsoleConnection { public: ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd) { - if (!ConcurrentConnectionManager::registerConnection()) { + if (!s_connManager.registerConnection()) { close(fd); throw std::runtime_error("Too many concurrent console connections"); } @@ -106,7 +74,7 @@ public: { if (d_fd != -1) { close(d_fd); - ConcurrentConnectionManager::releaseConnection(); + s_connManager.releaseConnection(); } } @@ -127,7 +95,7 @@ private: void setConsoleMaximumConcurrentConnections(size_t max) { - ConcurrentConnectionManager::setMaxConcurrentConnections(max); + s_connManager.setMaxConcurrentConnections(max); } // MUST BE CALLED UNDER A LOCK - right now the LuaLock diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index a686c886e3..ee7ae516bf 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -897,7 +897,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) }); - typedef std::unordered_map> > webserveropts_t; + typedef std::unordered_map> > webserveropts_t; luaCtx.writeFunction("setWebserverConfig", [](boost::optional vars) { setLuaSideEffect(); @@ -933,6 +933,10 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) if (vars->count("statsRequireAuthentication")) { setWebserverStatsRequireAuthentication(boost::get(vars->at("statsRequireAuthentication"))); } + + if (vars->count("maxConcurrentConnections")) { + setWebserverMaxConcurrentConnections(boost::get(vars->at("maxConcurrentConnections"))); + } }); luaCtx.writeFunction("controlSocket", [client,configCheck](const std::string& str) { diff --git a/pdns/dnsdist-web.cc b/pdns/dnsdist-web.cc index 3f1b90c40e..303331eecd 100644 --- a/pdns/dnsdist-web.cc +++ b/pdns/dnsdist-web.cc @@ -30,6 +30,7 @@ #include #include "base64.hh" +#include "connection-management.hh" #include "dnsdist.hh" #include "dnsdist-dynblocks.hh" #include "dnsdist-healthchecks.hh" @@ -61,6 +62,46 @@ WebserverConfig g_webserverConfig; std::string g_apiConfigDirectory; static const MetricDefinitionStorage s_metricDefinitions; +static ConcurrentConnectionManager s_connManager(100); + +class WebClientConnection +{ +public: + WebClientConnection(const ComboAddress& client, int fd): d_client(client), d_socket(fd) + { + if (!s_connManager.registerConnection()) { + throw std::runtime_error("Too many concurrent web client connections"); + } + } + WebClientConnection(WebClientConnection&& rhs): d_client(rhs.d_client), d_socket(std::move(rhs.d_socket)) + { + } + + WebClientConnection(const WebClientConnection&) = delete; + WebClientConnection& operator=(const WebClientConnection&) = delete; + + ~WebClientConnection() + { + if (d_socket.getHandle() != -1) { + s_connManager.releaseConnection(); + } + } + + const Socket& getSocket() const + { + return d_socket; + } + + const ComboAddress& getClient() const + { + return d_client; + } + +private: + ComboAddress d_client; + Socket d_socket; +}; + const std::map MetricDefinitionStorage::metrics{ { "responses", MetricDefinition(PrometheusMetricType::counter, "Number of responses received from backends") }, { "servfail-responses", MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers received from backends") }, @@ -1246,14 +1287,11 @@ void registerBuiltInWebHandlers() } } -static void connectionThread(int sockFD, ComboAddress remote) +static void connectionThread(WebClientConnection&& conn) { setThreadName("dnsdist/webConn"); - vinfolog("Webserver handling connection from %s", remote.toStringWithPort()); - - Socket sock(sockFD); - sockFD = -1; + vinfolog("Webserver handling connection from %s", conn.getClient().toStringWithPort()); try { YaHTTP::AsyncRequestLoader yarl; @@ -1264,7 +1302,7 @@ static void connectionThread(int sockFD, ComboAddress remote) while (!finished) { int bytes; char buf[1024]; - bytes = read(sock.getHandle(), buf, sizeof(buf)); + bytes = read(conn.getSocket().getHandle(), buf, sizeof(buf)); if (bytes > 0) { string data = string(buf, bytes); finished = yarl.feed(data); @@ -1300,7 +1338,7 @@ static void connectionThread(int sockFD, ComboAddress remote) else if (!handleAuthorization(req)) { YaHTTP::strstr_map_t::iterator header = req.headers.find("authorization"); if (header != req.headers.end()) { - errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, remote.toStringWithPort()); + errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, conn.getClient().toStringWithPort()); } resp.status = 401; resp.body = "

Unauthorized

"; @@ -1322,16 +1360,16 @@ static void connectionThread(int sockFD, ComboAddress remote) std::ostringstream ofs; ofs << resp; string done = ofs.str(); - writen2(sock.getHandle(), done.c_str(), done.size()); + writen2(conn.getSocket().getHandle(), done.c_str(), done.size()); } catch (const YaHTTP::ParseError& e) { - vinfolog("Webserver thread died with parse error exception while processing a request from %s: %s", remote.toStringWithPort(), e.what()); + vinfolog("Webserver thread died with parse error exception while processing a request from %s: %s", conn.getClient().toStringWithPort(), e.what()); } catch (const std::exception& e) { - errlog("Webserver thread died with exception while processing a request from %s: %s", remote.toStringWithPort(), e.what()); + errlog("Webserver thread died with exception while processing a request from %s: %s", conn.getClient().toStringWithPort(), e.what()); } catch (...) { - errlog("Webserver thread died with exception while processing a request from %s", remote.toStringWithPort()); + errlog("Webserver thread died with exception while processing a request from %s", conn.getClient().toStringWithPort()); } } @@ -1378,6 +1416,11 @@ void setWebserverStatsRequireAuthentication(bool require) g_webserverConfig.statsRequireAuthentication = require; } +void setWebserverMaxConcurrentConnections(size_t max) +{ + s_connManager.setMaxConcurrentConnections(max); +} + void dnsdistWebserverThread(int sock, const ComboAddress& local) { setThreadName("dnsdist/webserv"); @@ -1394,13 +1437,17 @@ void dnsdistWebserverThread(int sock, const ComboAddress& local) try { ComboAddress remote(local); int fd = SAccept(sock, remote); + if (!isClientAllowedByACL(remote)) { vinfolog("Connection to webserver from client %s is not allowed, closing", remote.toStringWithPort()); close(fd); continue; } + + WebClientConnection conn(remote, fd); vinfolog("Got a connection to the webserver from %s", remote.toStringWithPort()); - std::thread t(connectionThread, fd, remote); + + std::thread t(connectionThread, std::move(conn)); t.detach(); } catch (const std::exception& e) { diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index ad1576fc6f..da64c8c9b9 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -126,6 +126,7 @@ dnsdist_SOURCES = \ cachecleaner.hh \ capabilities.cc capabilities.hh \ circular_buffer.hh \ + connection-management.hh \ dns.cc dns.hh \ dnscrypt.cc dnscrypt.hh \ dnsdist-backend.cc \ diff --git a/pdns/dnsdistdist/connection-management.hh b/pdns/dnsdistdist/connection-management.hh new file mode 100644 index 0000000000..7dd4fd6b74 --- /dev/null +++ b/pdns/dnsdistdist/connection-management.hh @@ -0,0 +1,59 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include + +class ConcurrentConnectionManager +{ +public: + ConcurrentConnectionManager(size_t max): d_maxConcurrentConnections(max) + { + } + + void setMaxConcurrentConnections(size_t max) + { + std::lock_guard lock(d_concurrentConnectionsLock); + d_maxConcurrentConnections = max; + } + + bool registerConnection() + { + std::lock_guard lock(d_concurrentConnectionsLock); + if (d_maxConcurrentConnections == 0 || d_currentConnectionsCount < d_maxConcurrentConnections) { + ++d_currentConnectionsCount; + return true; + } + return false; + } + + void releaseConnection() + { + std::lock_guard lock(d_concurrentConnectionsLock); + --d_currentConnectionsCount; + } + +private: + std::mutex d_concurrentConnectionsLock; + size_t d_maxConcurrentConnections{0}; + size_t d_currentConnectionsCount{0}; +}; diff --git a/pdns/dnsdistdist/dnsdist-web.hh b/pdns/dnsdistdist/dnsdist-web.hh index d0d70e6821..ea869274c4 100644 --- a/pdns/dnsdistdist/dnsdist-web.hh +++ b/pdns/dnsdistdist/dnsdist-web.hh @@ -5,6 +5,7 @@ void setWebserverPassword(const std::string& password); void setWebserverACL(const std::string& acl); void setWebserverCustomHeaders(const boost::optional > customHeaders); void setWebserverStatsRequireAuthentication(bool); +void setWebserverMaxConcurrentConnections(size_t); void dnsdistWebserverThread(int sock, const ComboAddress& local); diff --git a/pdns/dnsdistdist/docs/reference/config.rst b/pdns/dnsdistdist/docs/reference/config.rst index a03ccdc0a9..c19c17fd8e 100644 --- a/pdns/dnsdistdist/docs/reference/config.rst +++ b/pdns/dnsdistdist/docs/reference/config.rst @@ -315,7 +315,7 @@ Webserver configuration ``acl`` optional parameter added. .. versionchanged:: 1.6.0 - ``statsRequireAuthentication`` optional parameter added. + ``statsRequireAuthentication``, ``maxConcurrentConnections`` optional parameters added. Setup webserver configuration. See :func:`webserver`. @@ -328,6 +328,7 @@ Webserver configuration * ``custom_headers={[str]=str,...}``: map of string - Allows setting custom headers and removing the defaults. * ``acl=newACL``: string - List of IP addresses, as a string, that are allowed to open a connection to the web server. Defaults to "127.0.0.1, ::1". * ``statsRequireAuthentication``: bool - Whether access to the statistics (/metrics and /jsonstat endpoints) require a valid password or API key. Defaults to true. + * ``maxConcurrentConnections``: int - The maximum number of concurrent web connections, or 0 which means an unlimited number. Defaults to 100. .. function:: registerWebHandler(path, handler) diff --git a/regression-tests.dnsdist/test_API.py b/regression-tests.dnsdist/test_API.py index b331ddb28a..461dc554f6 100644 --- a/regression-tests.dnsdist/test_API.py +++ b/regression-tests.dnsdist/test_API.py @@ -4,6 +4,8 @@ import os.path import base64 import json import requests +import socket +import time from dnsdisttests import DNSDistTest class TestAPIBasics(DNSDistTest): @@ -674,3 +676,44 @@ class TestCustomLuaEndpoint(DNSDistTest): self.assertEquals(r.status_code, 200) self.assertEquals(r.content, b'It works!') self.assertEquals(r.headers.get('foo'), "Bar") + +class TestWebConcurrentConnectionsL(DNSDistTest): + + _webTimeout = 2.0 + _webServerPort = 8083 + _webServerBasicAuthPassword = 'secret' + _webServerAPIKey = 'apisecret' + _maxConns = 2 + + _config_params = ['_testServerPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey', '_maxConns'] + _config_template = """ + newServer{address="127.0.0.1:%s"} + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s", maxConcurrentConnections=%d}) + """ + + def testConcurrentConnections(self): + """ + Web: Concurrent connections + """ + + conns = [] + # open the maximum number of connections + for _ in range(self._maxConns): + conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + conn.connect(("127.0.0.1", self._webServerPort)) + conns.append(conn) + + # we now hold all the slots, let's try to establish a new connection + url = 'http://127.0.0.1:' + str(self._webServerPort) + "/" + self.assertRaises(requests.exceptions.ConnectionError, requests.get, url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout) + + # free one slot + conns[0].close() + conns[0] = None + time.sleep(1) + + # this should work + r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout) + self.assertTrue(r) + self.assertEquals(r.status_code, 200)