]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist-console.cc
Merge pull request #11591 from rgacogne/ddist-mac-netlink
[thirdparty/pdns.git] / pdns / dnsdist-console.cc
index d112c2ade1a72ac95b5c715cd1a1f7caf6b4212f..73beb01aecbe94c48090e3dd957591a76ad690da 100644 (file)
@@ -58,16 +58,14 @@ static ConcurrentConnectionManager s_connManager(100);
 class ConsoleConnection
 {
 public:
-  ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd)
+  ConsoleConnection(const ComboAddress& client, FDWrapper&& fd): d_client(client), d_fd(std::move(fd))
   {
     if (!s_connManager.registerConnection()) {
-      close(fd);
       throw std::runtime_error("Too many concurrent console connections");
     }
   }
-  ConsoleConnection(ConsoleConnection&& rhs): d_client(rhs.d_client), d_fd(rhs.d_fd)
+  ConsoleConnection(ConsoleConnection&& rhs): d_client(rhs.d_client), d_fd(std::move(rhs.d_fd))
   {
-    rhs.d_fd = -1;
   }
 
   ConsoleConnection(const ConsoleConnection&) = delete;
@@ -75,15 +73,14 @@ public:
 
   ~ConsoleConnection()
   {
-    if (d_fd != -1) {
-      close(d_fd);
+    if (d_fd.getHandle() != -1) {
       s_connManager.releaseConnection();
     }
   }
 
   int getFD() const
   {
-    return d_fd;
+    return d_fd.getHandle();
   }
 
   const ComboAddress& getClient() const
@@ -93,7 +90,7 @@ public:
 
 private:
   ComboAddress d_client;
-  int d_fd{-1};
+  FDWrapper d_fd;
 };
 
 void setConsoleMaximumConcurrentConnections(size_t max)
@@ -133,26 +130,31 @@ static string historyFile(const bool &ignoreHOME = false)
 }
 #endif /* HAVE_LIBEDIT */
 
-static bool getMsgLen32(int fd, uint32_t* len)
+enum class ConsoleCommandResult : uint8_t {
+  Valid = 0,
+  ConnectionClosed,
+  TooLarge
+};
+
+static ConsoleCommandResult getMsgLen32(int fd, uint32_t* len)
 {
-  try
-  {
+  try {
     uint32_t raw;
-    size_t ret = readn2(fd, &raw, sizeof raw);
+    size_t ret = readn2(fd, &raw, sizeof(raw));
 
     if (ret != sizeof raw) {
-      return false;
+      return ConsoleCommandResult::ConnectionClosed;
     }
 
     *len = ntohl(raw);
     if (*len > g_consoleOutputMsgMaxSize) {
-      return false;
+      return ConsoleCommandResult::TooLarge;
     }
 
-    return true;
+    return ConsoleCommandResult::Valid;
   }
-  catch(...) {
-    return false;
+  catch (...) {
+    return ConsoleCommandResult::ConnectionClosed;
   }
 }
 
@@ -169,13 +171,13 @@ static bool putMsgLen32(int fd, uint32_t len)
   }
 }
 
-static bool sendMessageToServer(int fd, const std::string& line, SodiumNonce& readingNonce, SodiumNonce& writingNonce, const bool outputEmptyLine)
+static ConsoleCommandResult sendMessageToServer(int fd, const std::string& line, SodiumNonce& readingNonce, SodiumNonce& writingNonce, const bool outputEmptyLine)
 {
   string msg = sodEncryptSym(line, g_consoleKey, writingNonce);
   const auto msgLen = msg.length();
   if (msgLen > std::numeric_limits<uint32_t>::max()) {
-    cout << "Encrypted message is too long to be sent to the server, "<< std::to_string(msgLen) << " > " << std::numeric_limits<uint32_t>::max() << endl;
-    return true;
+    cerr << "Encrypted message is too long to be sent to the server, "<< std::to_string(msgLen) << " > " << std::numeric_limits<uint32_t>::max() << endl;
+    return ConsoleCommandResult::TooLarge;
   }
 
   putMsgLen32(fd, static_cast<uint32_t>(msgLen));
@@ -185,9 +187,14 @@ static bool sendMessageToServer(int fd, const std::string& line, SodiumNonce& re
   }
 
   uint32_t len;
-  if(!getMsgLen32(fd, &len)) {
+  auto commandResult = getMsgLen32(fd, &len);
+  if (commandResult == ConsoleCommandResult::ConnectionClosed) {
     cout << "Connection closed by the server." << endl;
-    return false;
+    return commandResult;
+  }
+  else if (commandResult == ConsoleCommandResult::TooLarge) {
+    cerr << "Received a console message whose length (" << len << ") is exceeding the allowed one (" << g_consoleOutputMsgMaxSize << "), closing that connection" << endl;
+    return commandResult;
   }
 
   if (len == 0) {
@@ -195,7 +202,7 @@ static bool sendMessageToServer(int fd, const std::string& line, SodiumNonce& re
       cout << endl;
     }
 
-    return true;
+    return ConsoleCommandResult::Valid;
   }
 
   msg.clear();
@@ -205,7 +212,7 @@ static bool sendMessageToServer(int fd, const std::string& line, SodiumNonce& re
   cout << msg;
   cout.flush();
 
-  return true;
+  return ConsoleCommandResult::Valid;
 }
 
 void doClient(ComboAddress server, const std::string& command)
@@ -219,34 +226,35 @@ void doClient(ComboAddress server, const std::string& command)
     cout<<"Connecting to "<<server.toStringWithPort()<<endl;
   }
 
-  int fd=socket(server.sin4.sin_family, SOCK_STREAM, 0);
-  if (fd < 0) {
+  auto fd = FDWrapper(socket(server.sin4.sin_family, SOCK_STREAM, 0));
+  if (fd.getHandle() < 0) {
     cerr<<"Unable to connect to "<<server.toStringWithPort()<<endl;
     return;
   }
-  SConnect(fd, server);
-  setTCPNoDelay(fd);
+  SConnect(fd.getHandle(), server);
+  setTCPNoDelay(fd.getHandle());
   SodiumNonce theirs, ours, readingNonce, writingNonce;
   ours.init();
 
-  writen2(fd, (const char*)ours.value, sizeof(ours.value));
-  readn2(fd, (char*)theirs.value, sizeof(theirs.value));
+  writen2(fd.getHandle(), (const char*)ours.value, sizeof(ours.value));
+  readn2(fd.getHandle(), (char*)theirs.value, sizeof(theirs.value));
   readingNonce.merge(ours, theirs);
   writingNonce.merge(theirs, ours);
 
   /* try sending an empty message, the server should send an empty
      one back. If it closes the connection instead, we are probably
      having a key mismatch issue. */
-  if (!sendMessageToServer(fd, "", readingNonce, writingNonce, false)) {
+  auto commandResult = sendMessageToServer(fd.getHandle(), "", readingNonce, writingNonce, false);
+  if (commandResult == ConsoleCommandResult::ConnectionClosed) {
     cerr<<"The server closed the connection right away, likely indicating a key mismatch. Please check your setKey() directive."<<endl;
-    close(fd);
+    return;
+  }
+  else if (commandResult == ConsoleCommandResult::TooLarge) {
     return;
   }
 
   if (!command.empty()) {
-    sendMessageToServer(fd, command, readingNonce, writingNonce, false);
-
-    close(fd);
+    sendMessageToServer(fd.getHandle(), command, readingNonce, writingNonce, false);
     return; 
   }
 
@@ -261,38 +269,42 @@ void doClient(ComboAddress server, const std::string& command)
   }
   ofstream history(histfile, std::ios_base::app);
   string lastline;
-  for(;;) {
+  for (;;) {
     char* sline = readline("> ");
     rl_bind_key('\t',rl_complete);
-    if(!sline)
+    if (!sline) {
       break;
+    }
 
     string line(sline);
-    if(!line.empty() && line != lastline) {
+    if (!line.empty() && line != lastline) {
       add_history(sline);
       history << sline <<endl;
       history.flush();
     }
-    lastline=line;
+    lastline = line;
     free(sline);
     
-    if(line=="quit")
+    if (line == "quit") {
       break;
-    if(line=="help" || line=="?")
-      line="help()";
+    }
+    if (line == "help" || line == "?") {
+      line = "help()";
+    }
 
     /* no need to send an empty line to the server */
-    if(line.empty())
+    if (line.empty()) {
       continue;
+    }
 
-    if (!sendMessageToServer(fd, line, readingNonce, writingNonce, true)) {
+    commandResult = sendMessageToServer(fd.getHandle(), line, readingNonce, writingNonce, true);
+    if (commandResult != ConsoleCommandResult::Valid) {
       break;
     }
   }
 #else
   errlog("Client mode requested but libedit support is not available");
 #endif /* HAVE_LIBEDIT */
-  close(fd);
 }
 
 #ifdef HAVE_LIBEDIT
@@ -527,6 +539,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "getTLSContext", true, "n", "returns the TLS context with index n" },
   { "getTLSFrontend", true, "n", "returns the TLS frontend with index n" },
   { "getTLSFrontendCount", true, "", "returns the number of DoT listeners" },
+  { "getVerbose", true, "", "get whether log messages at the verbose level will be logged" },
   { "grepq", true, "Netmask|DNS Name|100ms|{\"::1\", \"powerdns.com\", \"100ms\"} [, n]", "shows the last n queries and responses matching the specified client address or range (Netmask), or the specified DNS Name, or slower than 100ms" },
   { "hashPassword", true, "password [, workFactor]", "Returns a hashed and salted version of the supplied password, usable with 'setWebserverConfig()'"},
   { "HTTPHeaderRule", true, "name, regex", "matches DoH queries with a HTTP header 'name' whose content matches the regular expression 'regex'"},
@@ -694,6 +707,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setUDPMultipleMessagesVectorSize", true, "n", "set the size of the vector passed to recvmmsg() to receive UDP messages. Default to 1 which means that the feature is disabled and recvmsg() is used instead" },
   { "setUDPSocketBufferSizes", true, "recv, send", "Set the size of the receive (SO_RCVBUF) and send (SO_SNDBUF) buffers for incoming UDP sockets" },
   { "setUDPTimeout", true, "n", "set the maximum time dnsdist will wait for a response from a backend over UDP, in seconds" },
+  { "setVerbose", true, "bool", "set whether log messages at the verbose level will be logged" },
   { "setVerboseHealthChecks", true, "bool", "set whether health check errors will be logged" },
   { "setWebserverConfig", true, "[{password=string, apiKey=string, customHeaders, statsRequireAuthentication}]", "Updates webserver configuration" },
   { "setWeightedBalancingFactor", true, "factor", "Set the balancing factor for bounded-load weighted policies (whashed, wrandom)" },
@@ -778,13 +792,14 @@ static char* my_generator(const char* text, int state)
   string t(text);
   /* to keep it readable, we try to keep only 4 keywords per line
      and to start a new line when the first letter changes */
-  static int s_counter=0;
+  static int s_counter = 0;
   int counter=0;
-  if(!state)
-    s_counter=0;
+  if (!state) {
+    s_counter = 0;
+  }
 
-  for(const auto& keyword : g_consoleKeywords) {
-    if(boost::starts_with(keyword.name, t) && counter++ == s_counter)  {
+  for (const auto& keyword : g_consoleKeywords) {
+    if (boost::starts_with(keyword.name, t) && counter++ == s_counter)  {
       std::string value(keyword.name);
       s_counter++;
       if (keyword.function) {
@@ -802,8 +817,9 @@ static char* my_generator(const char* text, int state)
 char** my_completion( const char * text , int start,  int end)
 {
   char **matches=0;
-  if (start == 0)
+  if (start == 0) {
     matches = rl_completion_matches ((char*)text, &my_generator);
+  }
 
   // skip default filename completion.
   rl_attempted_completion_over = 1;
@@ -816,8 +832,7 @@ char** my_completion( const char * text , int start,  int end)
 
 static void controlClientThread(ConsoleConnection&& conn)
 {
-  try
-  {
+  try {
     setThreadName("dnsdist/conscli");
 
     setTCPNoDelay(conn.getFD());
@@ -829,9 +844,9 @@ static void controlClientThread(ConsoleConnection&& conn)
     readingNonce.merge(ours, theirs);
     writingNonce.merge(theirs, ours);
 
-    for(;;) {
+    for (;;) {
       uint32_t len;
-      if (!getMsgLen32(conn.getFD(), &len)) {
+      if (getMsgLen32(conn.getFD(), &len) != ConsoleCommandResult::Valid) {
         break;
       }
 
@@ -850,7 +865,7 @@ static void controlClientThread(ConsoleConnection&& conn)
 
       string response;
       try {
-        bool withReturn=true;
+        bool withReturn = true;
       retry:;
         try {
           auto lua = g_lua.lock();
@@ -868,39 +883,42 @@ static void controlClientThread(ConsoleConnection&& conn)
               >
             >(withReturn ? ("return "+line) : line);
 
-          if(ret) {
+          if (ret) {
             if (const auto dsValue = boost::get<shared_ptr<DownstreamState>>(&*ret)) {
               if (*dsValue) {
-                response=(*dsValue)->getName()+"\n";
+                response = (*dsValue)->getName()+"\n";
               } else {
-                response="";
+                response = "";
               }
             }
             else if (const auto csValue = boost::get<ClientState*>(&*ret)) {
               if (*csValue) {
-                response=(*csValue)->local.toStringWithPort()+"\n";
+                response = (*csValue)->local.toStringWithPort()+"\n";
               } else {
-                response="";
+                response = "";
               }
             }
             else if (const auto strValue = boost::get<string>(&*ret)) {
-              response=*strValue+"\n";
+              response = *strValue+"\n";
             }
-            else if(const auto um = boost::get<std::unordered_map<string, double> >(&*ret)) {
+            else if (const auto um = boost::get<std::unordered_map<string, double> >(&*ret)) {
               using namespace json11;
               Json::object o;
-              for(const auto& v : *um)
-                o[v.first]=v.second;
+              for(const auto& v : *um) {
+                o[v.first] = v.second;
+              }
               Json out = o;
-              response=out.dump()+"\n";
+              response = out.dump()+"\n";
             }
           }
-          else
-            response=g_outputBuffer;
-          if(!getLuaNoSideEffect())
+          else {
+            response = g_outputBuffer;
+          }
+          if (!getLuaNoSideEffect()) {
             feedConfigDelta(line);
+          }
         }
-        catch(const LuaContext::SyntaxErrorException&) {
+        catch (const LuaContext::SyntaxErrorException&) {
           if(withReturn) {
             withReturn=false;
             goto retry;
@@ -912,23 +930,26 @@ static void controlClientThread(ConsoleConnection&& conn)
         response = "Command returned an object we can't print: " +std::string(e.what()) + "\n";
         // tried to return something we don't understand
       }
-      catch(const LuaContext::ExecutionErrorException& e) {
-        if(!strcmp(e.what(),"invalid key to 'next'"))
+      catch (const LuaContext::ExecutionErrorException& e) {
+        if (!strcmp(e.what(),"invalid key to 'next'")) {
           response = "Error: Parsing function parameters, did you forget parameter name?";
-        else
+        }
+        else {
           response = "Error: " + string(e.what());
+        }
+
         try {
           std::rethrow_if_nested(e);
-        } catch(const std::exception& ne) {
+        } catch (const std::exception& ne) {
           // ne is the exception that was thrown from inside the lambda
           response+= ": " + string(ne.what());
         }
-        catch(const PDNSException& ne) {
+        catch (const PDNSException& ne) {
           // ne is the exception that was thrown from inside the lambda
           response += ": " + string(ne.reason);
         }
       }
-      catch(const LuaContext::SyntaxErrorException& e) {
+      catch (const LuaContext::SyntaxErrorException& e) {
         response = "Error: " + string(e.what()) + ": ";
       }
       response = sodEncryptSym(response, g_consoleKey, writingNonce);
@@ -939,14 +960,14 @@ static void controlClientThread(ConsoleConnection&& conn)
       infolog("Closed control connection from %s", conn.getClient().toStringWithPort());
     }
   }
-  catch (const std::exception& e)
-  {
+  catch (const std::exception& e) {
     errlog("Got an exception in client connection from %s: %s", conn.getClient().toStringWithPort(), e.what());
   }
 }
 
 void controlThread(int fd, ComboAddress local)
 {
+  FDWrapper acceptFD(fd);
   try
   {
     setThreadName("dnsdist/control");
@@ -955,22 +976,21 @@ void controlThread(int fd, ComboAddress local)
     auto localACL = g_consoleACL.getLocal();
     infolog("Accepting control connections on %s", local.toStringWithPort());
 
-    while ((sock = SAccept(fd, client)) >= 0) {
+    while ((sock = SAccept(acceptFD.getHandle(), client)) >= 0) {
 
+      FDWrapper socket(sock);
       if (!sodIsValidKey(g_consoleKey)) {
         vinfolog("Control connection from %s dropped because we don't have a valid key configured, please configure one using setKey()", client.toStringWithPort());
-        close(sock);
         continue;
       }
 
       if (!localACL->match(client)) {
         vinfolog("Control connection from %s dropped because of ACL", client.toStringWithPort());
-        close(sock);
         continue;
       }
 
       try {
-        ConsoleConnection conn(client, sock);
+        ConsoleConnection conn(client, std::move(socket));
         if (g_logConsoleConnections) {
           warnlog("Got control connection from %s", client.toStringWithPort());
         }
@@ -983,9 +1003,7 @@ void controlThread(int fd, ComboAddress local)
       }
     }
   }
-  catch (const std::exception& e)
-  {
-    close(fd);
+  catch (const std::exception& e) {
     errlog("Control thread died: %s", e.what());
   }
 }