]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Limit the number of concurrent console connections
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 21 Jan 2021 12:58:29 +0000 (13:58 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 29 Jan 2021 09:57:38 +0000 (10:57 +0100)
pdns/dnsdist-console.cc
pdns/dnsdist-console.hh
pdns/dnsdist-lua.cc
pdns/dnsdistdist/docs/reference/config.rst
regression-tests.dnsdist/test_Console.py

index c275e97f56547e636d3ccdf355ed04bce2578d74..3d8b02c1f3a063a36a69d53b04b376d368b8e335 100644 (file)
@@ -21,9 +21,9 @@
  */
 
 #include <fstream>
+// we need this to get the home directory of the current user
 #include <pwd.h>
 #include <thread>
-#include <boost/scoped_array.hpp>
 
 #if defined (__OpenBSD__) || defined(__NetBSD__)
 // If this is not undeffed, __attribute__ wil be redefined by /usr/include/readline/rlstdc.h
@@ -49,6 +49,87 @@ bool g_logConsoleConnections{true};
 bool g_consoleEnabled{false};
 uint32_t g_consoleOutputMsgMaxSize{10000000};
 
+class ConcurrentConnectionManager
+{
+public:
+  static void setMaxConcurrentConnections(size_t max)
+  {
+    std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
+    s_maxConcurrentConsoleConnections = max;
+  }
+
+  static bool registerConnection()
+  {
+    std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
+    if (s_maxConcurrentConsoleConnections == 0 || s_currentConnectionsCount < s_maxConcurrentConsoleConnections) {
+      ++s_currentConnectionsCount;
+      return true;
+    }
+    return false;
+  }
+
+  static void releaseConnection()
+  {
+    std::lock_guard<decltype(s_concurrentConsoleConnectionsLock)> lock(s_concurrentConsoleConnectionsLock);
+    --s_currentConnectionsCount;
+  }
+
+private:
+  static std::mutex s_concurrentConsoleConnectionsLock;
+  static size_t s_maxConcurrentConsoleConnections;
+  static size_t s_currentConnectionsCount;
+};
+
+size_t ConcurrentConnectionManager::s_maxConcurrentConsoleConnections{100};
+size_t ConcurrentConnectionManager::s_currentConnectionsCount{0};
+std::mutex ConcurrentConnectionManager::s_concurrentConsoleConnectionsLock;
+
+class ConsoleConnection
+{
+public:
+  ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd)
+  {
+    if (!ConcurrentConnectionManager::registerConnection()) {
+      close(fd);
+      throw std::runtime_error("Too many concurrent console connections");
+    }
+  }
+  ConsoleConnection(ConsoleConnection&& rhs): d_client(rhs.d_client), d_fd(rhs.d_fd)
+  {
+    rhs.d_fd = -1;
+  }
+
+  ConsoleConnection(const ConsoleConnection&) = delete;
+  ConsoleConnection& operator=(const ConsoleConnection&) = delete;
+
+  ~ConsoleConnection()
+  {
+    if (d_fd != -1) {
+      close(d_fd);
+      ConcurrentConnectionManager::releaseConnection();
+    }
+  }
+
+  int getFD() const
+  {
+    return d_fd;
+  }
+
+  const ComboAddress& getClient() const
+  {
+    return d_client;
+  }
+
+private:
+  ComboAddress d_client;
+  int d_fd{-1};
+};
+
+void setConsoleMaximumConcurrentConnections(size_t max)
+{
+  ConcurrentConnectionManager::setMaxConcurrentConnections(max);
+}
+
 // MUST BE CALLED UNDER A LOCK - right now the LuaLock
 static void feedConfigDelta(const std::string& line)
 {
@@ -144,9 +225,9 @@ static bool sendMessageToServer(int fd, const std::string& line, SodiumNonce& re
     return true;
   }
 
-  boost::scoped_array<char> resp(new char[len]);
-  readn2(fd, resp.get(), len);
-  msg.assign(resp.get(), len);
+  msg.clear();
+  msg.resize(len);
+  readn2(fd, msg.data(), len);
   msg = sodDecryptSym(msg, g_consoleKey, readingNonce);
   cout << msg;
   cout.flush();
@@ -161,7 +242,7 @@ void doClient(ComboAddress server, const std::string& command)
     return;
   }
 
-  if(g_verbose) {
+  if (g_verbose) {
     cout<<"Connecting to "<<server.toStringWithPort()<<endl;
   }
 
@@ -525,6 +606,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setConsistentHashingBalancingFactor", true, "factor", "Set the balancing factor for bounded-load consistent hashing" },
   { "setConsoleACL", true, "{netmask, netmask}", "replace the console ACL set with these netmasks" },
   { "setConsoleConnectionsLogging", true, "enabled", "whether to log the opening and closing of console connections" },
+  { "setConsoleMaximumConcurrentConnections", true, "max", "Set the maximum number of concurrent console connections" },
   { "setConsoleOutputMaxMsgSize", true, "messageSize", "set console message maximum size in bytes, default is 10 MB" },
   { "setDefaultBPFFilter", true, "filter", "When used at configuration time, the corresponding BPFFilter will be attached to every bind" },
   { "setDynBlocksAction", true, "action", "set which action is performed when a query is blocked. Only DNSAction.Drop (the default) and DNSAction.Refused are supported" },
@@ -683,38 +765,40 @@ char** my_completion( const char * text , int start,  int end)
 }
 }
 
-static void controlClientThread(int fd, ComboAddress client)
+static void controlClientThread(ConsoleConnection&& conn)
 {
   try
   {
     setThreadName("dnsdist/conscli");
-    setTCPNoDelay(fd);
+
+    setTCPNoDelay(conn.getFD());
+
     SodiumNonce theirs, ours, readingNonce, writingNonce;
     ours.init();
-    readn2(fd, (char*)theirs.value, sizeof(theirs.value));
-    writen2(fd, (char*)ours.value, sizeof(ours.value));
+    readn2(conn.getFD(), (char*)theirs.value, sizeof(theirs.value));
+    writen2(conn.getFD(), (char*)ours.value, sizeof(ours.value));
     readingNonce.merge(ours, theirs);
     writingNonce.merge(theirs, ours);
 
     for(;;) {
       uint32_t len;
-      if(!getMsgLen32(fd, &len))
+      if (!getMsgLen32(conn.getFD(), &len)) {
         break;
+      }
 
       if (len == 0) {
         /* just ACK an empty message
            with an empty response */
-        putMsgLen32(fd, 0);
+        putMsgLen32(conn.getFD(), 0);
         continue;
       }
 
-      boost::scoped_array<char> msg(new char[len]);
-      readn2(fd, msg.get(), len);
-
-      string line(msg.get(), len);
+      std::string line;
+      line.resize(len);
+      readn2(conn.getFD(), line.data(), len);
 
       line = sodDecryptSym(line, g_consoleKey, readingNonce);
-      //    cerr<<"Have decrypted line: "<<line<<endl;
+
       string response;
       try {
         bool withReturn=true;
@@ -799,20 +883,16 @@ static void controlClientThread(int fd, ComboAddress client)
         response = "Error: " + string(e.what()) + ": ";
       }
       response = sodEncryptSym(response, g_consoleKey, writingNonce);
-      putMsgLen32(fd, response.length());
-      writen2(fd, response.c_str(), response.length());
+      putMsgLen32(conn.getFD(), response.length());
+      writen2(conn.getFD(), response.c_str(), response.length());
     }
     if (g_logConsoleConnections) {
-      infolog("Closed control connection from %s", client.toStringWithPort());
+      infolog("Closed control connection from %s", conn.getClient().toStringWithPort());
     }
-    close(fd);
-    fd=-1;
   }
   catch (const std::exception& e)
   {
-    errlog("Got an exception in client connection from %s: %s", client.toStringWithPort(), e.what());
-    if(fd >= 0)
-      close(fd);
+    errlog("Got an exception in client connection from %s: %s", conn.getClient().toStringWithPort(), e.what());
   }
 }
 
@@ -840,18 +920,24 @@ void controlThread(int fd, ComboAddress local)
         continue;
       }
 
-      if (g_logConsoleConnections) {
-        warnlog("Got control connection from %s", client.toStringWithPort());
-      }
+      try {
+        ConsoleConnection conn(client, sock);
+        if (g_logConsoleConnections) {
+          warnlog("Got control connection from %s", client.toStringWithPort());
+        }
 
-      std::thread t(controlClientThread, sock, client);
-      t.detach();
+        std::thread t(controlClientThread, std::move(conn));
+        t.detach();
+      }
+      catch (const std::exception& e) {
+        errlog("Control connection died: %s", e.what());
+      }
     }
   }
   catch (const std::exception& e)
   {
     close(fd);
-    errlog("Control connection died: %s", e.what());
+    errlog("Control thread died: %s", e.what());
   }
 }
 
index 70800644a56098de7508b55d321809c6b91ea41c..5697c8ac8e3d31f7ee88936b58c9a2693a7f1479 100644 (file)
@@ -52,3 +52,5 @@ char** my_completion( const char * text , int start,  int end);
 }
 void controlThread(int fd, ComboAddress local);
 void clearConsoleHistory();
+
+void setConsoleMaximumConcurrentConnections(size_t max);
index a65be53b70f78f3179741d8302c4f9674003c26d..a686c886e3453f5075275ca667d6419cb1fd8a4e 100644 (file)
@@ -1014,6 +1014,11 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       }
     });
 
+  luaCtx.writeFunction("setConsoleMaximumConcurrentConnections", [](size_t max) {
+      setLuaSideEffect();
+      setConsoleMaximumConcurrentConnections(max);
+  });
+
   luaCtx.writeFunction("clearQueryCounters", []() {
       unsigned int size{0};
       {
index 48f240233ea8ce7b51d6852e8c15b77e1c8a1361..a03ccdc0a9029ffe1491d6d686aab3456942d3bf 100644 (file)
@@ -244,6 +244,14 @@ Control Socket, Console and Webserver
 
   :param bool enabled: Default to true.
 
+.. function:: setConsoleMaximumConcurrentConnections(max)
+
+  .. versionadded:: 1.6.0
+
+  Set the maximum number of concurrent console connections.
+
+  :param int max: The maximum number of concurrent console connections, or 0 which means an unlimited number. Defaults to 100
+
 .. function:: setKey(key)
 
   Use ``key`` as shared secret between the client and the server
index 952cc2afa6a151a1b232512ee72a2764537cfc10..ee3796125127a66ce73453f1df2585dac66fcda4 100644 (file)
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 import base64
 import dns
-from socket import error as SocketError
+import socket
+import time
 from dnsdisttests import DNSDistTest
 
 class TestConsoleAllowed(DNSDistTest):
@@ -40,7 +41,7 @@ class TestConsoleNotAllowed(DNSDistTest):
         """
         Console: Not allowed by the ACL
         """
-        self.assertRaises(SocketError, self.sendConsoleCommand, 'showVersion()')
+        self.assertRaises(socket.error, self.sendConsoleCommand, 'showVersion()')
 
 class TestConsoleNoKey(DNSDistTest):
 
@@ -57,4 +58,41 @@ class TestConsoleNoKey(DNSDistTest):
         """
         Console: No key, the connection should not be allowed
         """
-        self.assertRaises(SocketError, self.sendConsoleCommand, 'showVersion()')
+        self.assertRaises(socket.error, self.sendConsoleCommand, 'showVersion()')
+
+class TestConsoleConcurrentConnections(DNSDistTest):
+
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _maxConns = 2
+
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_maxConns']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+    newServer{address="127.0.0.1:%d"}
+    setConsoleMaximumConcurrentConnections(%d)
+    """
+
+    def testConsoleConnectionsLimit(self):
+        """
+        Console: Check the maximum number of connections
+        """
+        conns = []
+        # open the maximum number of connections
+        for _ in range(self._maxConns):
+            conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            conn.connect(("127.0.0.1", self._consolePort))
+            conns.append(conn)
+
+        # we now hold all the slots, let's try to establish a new connection
+        self.assertRaises(socket.error, self.sendConsoleCommand, 'showVersion()')
+
+        # free one slot
+        conns[0].close()
+        conns[0] = None
+        time.sleep(1)
+
+        # this should work
+        version = self.sendConsoleCommand('showVersion()')
+        self.assertTrue(version.startswith('dnsdist '))