#include "ext/json11/json11.hpp"
+#include "connection-management.hh"
#include "dolog.hh"
#include "dnsdist.hh"
#include "dnsdist-console.hh"
bool g_consoleEnabled{false};
uint32_t g_consoleOutputMsgMaxSize{10000000};
-class ConcurrentConnectionManager
-{
-public:
- static void setMaxConcurrentConnections(size_t max)
- {
- std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
- s_maxConcurrentConsoleConnections = max;
- }
-
- static bool registerConnection()
- {
- std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
- if (s_maxConcurrentConsoleConnections == 0 || s_currentConnectionsCount < s_maxConcurrentConsoleConnections) {
- ++s_currentConnectionsCount;
- return true;
- }
- return false;
- }
-
- static void releaseConnection()
- {
- std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
- --s_currentConnectionsCount;
- }
-
-private:
- static std::mutex s_concurrentConsoleConnectionsLock;
- static size_t s_maxConcurrentConsoleConnections;
- static size_t s_currentConnectionsCount;
-};
-
-size_t ConcurrentConnectionManager::s_maxConcurrentConsoleConnections{100};
-size_t ConcurrentConnectionManager::s_currentConnectionsCount{0};
-std::mutex ConcurrentConnectionManager::s_concurrentConsoleConnectionsLock;
+static ConcurrentConnectionManager s_connManager(100);
class ConsoleConnection
{
public:
ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd)
{
- if (!ConcurrentConnectionManager::registerConnection()) {
+ if (!s_connManager.registerConnection()) {
close(fd);
throw std::runtime_error("Too many concurrent console connections");
}
{
if (d_fd != -1) {
close(d_fd);
- ConcurrentConnectionManager::releaseConnection();
+ s_connManager.releaseConnection();
}
}
void setConsoleMaximumConcurrentConnections(size_t max)
{
- ConcurrentConnectionManager::setMaxConcurrentConnections(max);
+ s_connManager.setMaxConcurrentConnections(max);
}
// MUST BE CALLED UNDER A LOCK - right now the LuaLock
});
- typedef std::unordered_map<std::string, boost::variant<bool, std::string, std::map<std::string, std::string>> > webserveropts_t;
+ typedef std::unordered_map<std::string, boost::variant<bool, size_t, std::string, std::map<std::string, std::string>> > webserveropts_t;
luaCtx.writeFunction("setWebserverConfig", [](boost::optional<webserveropts_t> vars) {
setLuaSideEffect();
if (vars->count("statsRequireAuthentication")) {
setWebserverStatsRequireAuthentication(boost::get<bool>(vars->at("statsRequireAuthentication")));
}
+
+ if (vars->count("maxConcurrentConnections")) {
+ setWebserverMaxConcurrentConnections(boost::get<size_t>(vars->at("maxConcurrentConnections")));
+ }
});
luaCtx.writeFunction("controlSocket", [client,configCheck](const std::string& str) {
#include <yahttp/yahttp.hpp>
#include "base64.hh"
+#include "connection-management.hh"
#include "dnsdist.hh"
#include "dnsdist-dynblocks.hh"
#include "dnsdist-healthchecks.hh"
std::string g_apiConfigDirectory;
static const MetricDefinitionStorage s_metricDefinitions;
+static ConcurrentConnectionManager s_connManager(100);
+
+class WebClientConnection
+{
+public:
+ WebClientConnection(const ComboAddress& client, int fd): d_client(client), d_socket(fd)
+ {
+ if (!s_connManager.registerConnection()) {
+ throw std::runtime_error("Too many concurrent web client connections");
+ }
+ }
+ WebClientConnection(WebClientConnection&& rhs): d_client(rhs.d_client), d_socket(std::move(rhs.d_socket))
+ {
+ }
+
+ WebClientConnection(const WebClientConnection&) = delete;
+ WebClientConnection& operator=(const WebClientConnection&) = delete;
+
+ ~WebClientConnection()
+ {
+ if (d_socket.getHandle() != -1) {
+ s_connManager.releaseConnection();
+ }
+ }
+
+ const Socket& getSocket() const
+ {
+ return d_socket;
+ }
+
+ const ComboAddress& getClient() const
+ {
+ return d_client;
+ }
+
+private:
+ ComboAddress d_client;
+ Socket d_socket;
+};
+
const std::map<std::string, MetricDefinition> MetricDefinitionStorage::metrics{
{ "responses", MetricDefinition(PrometheusMetricType::counter, "Number of responses received from backends") },
{ "servfail-responses", MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers received from backends") },
}
}
-static void connectionThread(int sockFD, ComboAddress remote)
+static void connectionThread(WebClientConnection&& conn)
{
setThreadName("dnsdist/webConn");
- vinfolog("Webserver handling connection from %s", remote.toStringWithPort());
-
- Socket sock(sockFD);
- sockFD = -1;
+ vinfolog("Webserver handling connection from %s", conn.getClient().toStringWithPort());
try {
YaHTTP::AsyncRequestLoader yarl;
while (!finished) {
int bytes;
char buf[1024];
- bytes = read(sock.getHandle(), buf, sizeof(buf));
+ bytes = read(conn.getSocket().getHandle(), buf, sizeof(buf));
if (bytes > 0) {
string data = string(buf, bytes);
finished = yarl.feed(data);
else if (!handleAuthorization(req)) {
YaHTTP::strstr_map_t::iterator header = req.headers.find("authorization");
if (header != req.headers.end()) {
- errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, remote.toStringWithPort());
+ errlog("HTTP Request \"%s\" from %s: Web Authentication failed", req.url.path, conn.getClient().toStringWithPort());
}
resp.status = 401;
resp.body = "<h1>Unauthorized</h1>";
std::ostringstream ofs;
ofs << resp;
string done = ofs.str();
- writen2(sock.getHandle(), done.c_str(), done.size());
+ writen2(conn.getSocket().getHandle(), done.c_str(), done.size());
}
catch (const YaHTTP::ParseError& e) {
- vinfolog("Webserver thread died with parse error exception while processing a request from %s: %s", remote.toStringWithPort(), e.what());
+ vinfolog("Webserver thread died with parse error exception while processing a request from %s: %s", conn.getClient().toStringWithPort(), e.what());
}
catch (const std::exception& e) {
- errlog("Webserver thread died with exception while processing a request from %s: %s", remote.toStringWithPort(), e.what());
+ errlog("Webserver thread died with exception while processing a request from %s: %s", conn.getClient().toStringWithPort(), e.what());
}
catch (...) {
- errlog("Webserver thread died with exception while processing a request from %s", remote.toStringWithPort());
+ errlog("Webserver thread died with exception while processing a request from %s", conn.getClient().toStringWithPort());
}
}
g_webserverConfig.statsRequireAuthentication = require;
}
+void setWebserverMaxConcurrentConnections(size_t max)
+{
+ s_connManager.setMaxConcurrentConnections(max);
+}
+
void dnsdistWebserverThread(int sock, const ComboAddress& local)
{
setThreadName("dnsdist/webserv");
try {
ComboAddress remote(local);
int fd = SAccept(sock, remote);
+
if (!isClientAllowedByACL(remote)) {
vinfolog("Connection to webserver from client %s is not allowed, closing", remote.toStringWithPort());
close(fd);
continue;
}
+
+ WebClientConnection conn(remote, fd);
vinfolog("Got a connection to the webserver from %s", remote.toStringWithPort());
- std::thread t(connectionThread, fd, remote);
+
+ std::thread t(connectionThread, std::move(conn));
t.detach();
}
catch (const std::exception& e) {
cachecleaner.hh \
capabilities.cc capabilities.hh \
circular_buffer.hh \
+ connection-management.hh \
dns.cc dns.hh \
dnscrypt.cc dnscrypt.hh \
dnsdist-backend.cc \
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <mutex>
+
+class ConcurrentConnectionManager
+{
+public:
+ ConcurrentConnectionManager(size_t max): d_maxConcurrentConnections(max)
+ {
+ }
+
+ void setMaxConcurrentConnections(size_t max)
+ {
+ std::lock_guard<decltype(d_concurrentConnectionsLock)> lock(d_concurrentConnectionsLock);
+ d_maxConcurrentConnections = max;
+ }
+
+ bool registerConnection()
+ {
+ std::lock_guard<decltype(d_concurrentConnectionsLock)> lock(d_concurrentConnectionsLock);
+ if (d_maxConcurrentConnections == 0 || d_currentConnectionsCount < d_maxConcurrentConnections) {
+ ++d_currentConnectionsCount;
+ return true;
+ }
+ return false;
+ }
+
+ void releaseConnection()
+ {
+ std::lock_guard<decltype(d_concurrentConnectionsLock)> lock(d_concurrentConnectionsLock);
+ --d_currentConnectionsCount;
+ }
+
+private:
+ std::mutex d_concurrentConnectionsLock;
+ size_t d_maxConcurrentConnections{0};
+ size_t d_currentConnectionsCount{0};
+};
void setWebserverACL(const std::string& acl);
void setWebserverCustomHeaders(const boost::optional<std::map<std::string, std::string> > customHeaders);
void setWebserverStatsRequireAuthentication(bool);
+void setWebserverMaxConcurrentConnections(size_t);
void dnsdistWebserverThread(int sock, const ComboAddress& local);
``acl`` optional parameter added.
.. versionchanged:: 1.6.0
- ``statsRequireAuthentication`` optional parameter added.
+ ``statsRequireAuthentication``, ``maxConcurrentConnections`` optional parameters added.
Setup webserver configuration. See :func:`webserver`.
* ``custom_headers={[str]=str,...}``: map of string - Allows setting custom headers and removing the defaults.
* ``acl=newACL``: string - List of IP addresses, as a string, that are allowed to open a connection to the web server. Defaults to "127.0.0.1, ::1".
* ``statsRequireAuthentication``: bool - Whether access to the statistics (/metrics and /jsonstat endpoints) require a valid password or API key. Defaults to true.
+ * ``maxConcurrentConnections``: int - The maximum number of concurrent web connections, or 0 which means an unlimited number. Defaults to 100.
.. function:: registerWebHandler(path, handler)
import base64
import json
import requests
+import socket
+import time
from dnsdisttests import DNSDistTest
class TestAPIBasics(DNSDistTest):
self.assertEquals(r.status_code, 200)
self.assertEquals(r.content, b'It works!')
self.assertEquals(r.headers.get('foo'), "Bar")
+
+class TestWebConcurrentConnectionsL(DNSDistTest):
+
+ _webTimeout = 2.0
+ _webServerPort = 8083
+ _webServerBasicAuthPassword = 'secret'
+ _webServerAPIKey = 'apisecret'
+ _maxConns = 2
+
+ _config_params = ['_testServerPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey', '_maxConns']
+ _config_template = """
+ newServer{address="127.0.0.1:%s"}
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s", maxConcurrentConnections=%d})
+ """
+
+ def testConcurrentConnections(self):
+ """
+ Web: Concurrent connections
+ """
+
+ conns = []
+ # open the maximum number of connections
+ for _ in range(self._maxConns):
+ conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ conn.connect(("127.0.0.1", self._webServerPort))
+ conns.append(conn)
+
+ # we now hold all the slots, let's try to establish a new connection
+ url = 'http://127.0.0.1:' + str(self._webServerPort) + "/"
+ self.assertRaises(requests.exceptions.ConnectionError, requests.get, url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout)
+
+ # free one slot
+ conns[0].close()
+ conns[0] = None
+ time.sleep(1)
+
+ # this should work
+ r = requests.get(url, auth=('whatever', self._webServerBasicAuthPassword), timeout=self._webTimeout)
+ self.assertTrue(r)
+ self.assertEquals(r.status_code, 200)