*/
#include <fstream>
+// we need this to get the home directory of the current user
#include <pwd.h>
#include <thread>
-#include <boost/scoped_array.hpp>
#if defined (__OpenBSD__) || defined(__NetBSD__)
// If this is not undeffed, __attribute__ wil be redefined by /usr/include/readline/rlstdc.h
bool g_consoleEnabled{false};
uint32_t g_consoleOutputMsgMaxSize{10000000};
+class ConcurrentConnectionManager
+{
+public:
+ static void setMaxConcurrentConnections(size_t max)
+ {
+ std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
+ s_maxConcurrentConsoleConnections = max;
+ }
+
+ static bool registerConnection()
+ {
+ std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
+ if (s_maxConcurrentConsoleConnections == 0 || s_currentConnectionsCount < s_maxConcurrentConsoleConnections) {
+ ++s_currentConnectionsCount;
+ return true;
+ }
+ return false;
+ }
+
+ static void releaseConnection()
+ {
+ std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
+ --s_currentConnectionsCount;
+ }
+
+private:
+ static std::mutex s_concurrentConsoleConnectionsLock;
+ static size_t s_maxConcurrentConsoleConnections;
+ static size_t s_currentConnectionsCount;
+};
+
+size_t ConcurrentConnectionManager::s_maxConcurrentConsoleConnections{100};
+size_t ConcurrentConnectionManager::s_currentConnectionsCount{0};
+std::mutex ConcurrentConnectionManager::s_concurrentConsoleConnectionsLock;
+
+class ConsoleConnection
+{
+public:
+ ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd)
+ {
+ if (!ConcurrentConnectionManager::registerConnection()) {
+ close(fd);
+ throw std::runtime_error("Too many concurrent console connections");
+ }
+ }
+ ConsoleConnection(ConsoleConnection&& rhs): d_client(rhs.d_client), d_fd(rhs.d_fd)
+ {
+ rhs.d_fd = -1;
+ }
+
+ ConsoleConnection(const ConsoleConnection&) = delete;
+ ConsoleConnection& operator=(const ConsoleConnection&) = delete;
+
+ ~ConsoleConnection()
+ {
+ if (d_fd != -1) {
+ close(d_fd);
+ ConcurrentConnectionManager::releaseConnection();
+ }
+ }
+
+ int getFD() const
+ {
+ return d_fd;
+ }
+
+ const ComboAddress& getClient() const
+ {
+ return d_client;
+ }
+
+private:
+ ComboAddress d_client;
+ int d_fd{-1};
+};
+
+void setConsoleMaximumConcurrentConnections(size_t max)
+{
+ ConcurrentConnectionManager::setMaxConcurrentConnections(max);
+}
+
// MUST BE CALLED UNDER A LOCK - right now the LuaLock
static void feedConfigDelta(const std::string& line)
{
return true;
}
- boost::scoped_array<char> resp(new char[len]);
- readn2(fd, resp.get(), len);
- msg.assign(resp.get(), len);
+ msg.clear();
+ msg.resize(len);
+ readn2(fd, msg.data(), len);
msg = sodDecryptSym(msg, g_consoleKey, readingNonce);
cout << msg;
cout.flush();
return;
}
- if(g_verbose) {
+ if (g_verbose) {
cout<<"Connecting to "<<server.toStringWithPort()<<endl;
}
{ "setConsistentHashingBalancingFactor", true, "factor", "Set the balancing factor for bounded-load consistent hashing" },
{ "setConsoleACL", true, "{netmask, netmask}", "replace the console ACL set with these netmasks" },
{ "setConsoleConnectionsLogging", true, "enabled", "whether to log the opening and closing of console connections" },
+ { "setConsoleMaximumConcurrentConnections", true, "max", "Set the maximum number of concurrent console connections" },
{ "setConsoleOutputMaxMsgSize", true, "messageSize", "set console message maximum size in bytes, default is 10 MB" },
{ "setDefaultBPFFilter", true, "filter", "When used at configuration time, the corresponding BPFFilter will be attached to every bind" },
{ "setDynBlocksAction", true, "action", "set which action is performed when a query is blocked. Only DNSAction.Drop (the default) and DNSAction.Refused are supported" },
}
}
-static void controlClientThread(int fd, ComboAddress client)
+static void controlClientThread(ConsoleConnection&& conn)
{
try
{
setThreadName("dnsdist/conscli");
- setTCPNoDelay(fd);
+
+ setTCPNoDelay(conn.getFD());
+
SodiumNonce theirs, ours, readingNonce, writingNonce;
ours.init();
- readn2(fd, (char*)theirs.value, sizeof(theirs.value));
- writen2(fd, (char*)ours.value, sizeof(ours.value));
+ readn2(conn.getFD(), (char*)theirs.value, sizeof(theirs.value));
+ writen2(conn.getFD(), (char*)ours.value, sizeof(ours.value));
readingNonce.merge(ours, theirs);
writingNonce.merge(theirs, ours);
for(;;) {
uint32_t len;
- if(!getMsgLen32(fd, &len))
+ if (!getMsgLen32(conn.getFD(), &len)) {
break;
+ }
if (len == 0) {
/* just ACK an empty message
with an empty response */
- putMsgLen32(fd, 0);
+ putMsgLen32(conn.getFD(), 0);
continue;
}
- boost::scoped_array<char> msg(new char[len]);
- readn2(fd, msg.get(), len);
-
- string line(msg.get(), len);
+ std::string line;
+ line.resize(len);
+ readn2(conn.getFD(), line.data(), len);
line = sodDecryptSym(line, g_consoleKey, readingNonce);
- // cerr<<"Have decrypted line: "<<line<<endl;
+
string response;
try {
bool withReturn=true;
response = "Error: " + string(e.what()) + ": ";
}
response = sodEncryptSym(response, g_consoleKey, writingNonce);
- putMsgLen32(fd, response.length());
- writen2(fd, response.c_str(), response.length());
+ putMsgLen32(conn.getFD(), response.length());
+ writen2(conn.getFD(), response.c_str(), response.length());
}
if (g_logConsoleConnections) {
- infolog("Closed control connection from %s", client.toStringWithPort());
+ infolog("Closed control connection from %s", conn.getClient().toStringWithPort());
}
- close(fd);
- fd=-1;
}
catch (const std::exception& e)
{
- errlog("Got an exception in client connection from %s: %s", client.toStringWithPort(), e.what());
- if(fd >= 0)
- close(fd);
+ errlog("Got an exception in client connection from %s: %s", conn.getClient().toStringWithPort(), e.what());
}
}
continue;
}
- if (g_logConsoleConnections) {
- warnlog("Got control connection from %s", client.toStringWithPort());
- }
+ try {
+ ConsoleConnection conn(client, sock);
+ if (g_logConsoleConnections) {
+ warnlog("Got control connection from %s", client.toStringWithPort());
+ }
- std::thread t(controlClientThread, sock, client);
- t.detach();
+ std::thread t(controlClientThread, std::move(conn));
+ t.detach();
+ }
+ catch (const std::exception& e) {
+ errlog("Control connection died: %s", e.what());
+ }
}
}
catch (const std::exception& e)
{
close(fd);
- errlog("Control connection died: %s", e.what());
+ errlog("Control thread died: %s", e.what());
}
}
#!/usr/bin/env python
import base64
import dns
-from socket import error as SocketError
+import socket
+import time
from dnsdisttests import DNSDistTest
class TestConsoleAllowed(DNSDistTest):
"""
Console: Not allowed by the ACL
"""
- self.assertRaises(SocketError, self.sendConsoleCommand, 'showVersion()')
+ self.assertRaises(socket.error, self.sendConsoleCommand, 'showVersion()')
class TestConsoleNoKey(DNSDistTest):
"""
Console: No key, the connection should not be allowed
"""
- self.assertRaises(SocketError, self.sendConsoleCommand, 'showVersion()')
+ self.assertRaises(socket.error, self.sendConsoleCommand, 'showVersion()')
+
+class TestConsoleConcurrentConnections(DNSDistTest):
+
+ _consoleKey = DNSDistTest.generateConsoleKey()
+ _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+ _maxConns = 2
+
+ _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_maxConns']
+ _config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%s")
+ newServer{address="127.0.0.1:%d"}
+ setConsoleMaximumConcurrentConnections(%d)
+ """
+
+ def testConsoleConnectionsLimit(self):
+ """
+ Console: Check the maximum number of connections
+ """
+ conns = []
+ # open the maximum number of connections
+ for _ in range(self._maxConns):
+ conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ conn.connect(("127.0.0.1", self._consolePort))
+ conns.append(conn)
+
+ # we now hold all the slots, let's try to establish a new connection
+ self.assertRaises(socket.error, self.sendConsoleCommand, 'showVersion()')
+
+ # free one slot
+ conns[0].close()
+ conns[0] = None
+ time.sleep(1)
+
+ # this should work
+ version = self.sendConsoleCommand('showVersion()')
+ self.assertTrue(version.startswith('dnsdist '))