]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent leaking the console's socket descriptors 11543/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 15 Apr 2022 08:38:08 +0000 (10:38 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 15 Apr 2022 08:38:08 +0000 (10:38 +0200)
I don't see any case where that could happen but better wrap the
socket descriptors in FDWrapper objects so that we cannot forget to
close them if an exception is raised, for example.

pdns/dnsdist-console.cc

index 61baa69336f7ef23d8d50f9ee14009f844255cae..2058e804d2456396fb518423935e0ae68f4e6caf 100644 (file)
@@ -58,16 +58,14 @@ static ConcurrentConnectionManager s_connManager(100);
 class ConsoleConnection
 {
 public:
-  ConsoleConnection(const ComboAddress& client, int fd): d_client(client), d_fd(fd)
+  ConsoleConnection(const ComboAddress& client, FDWrapper&& fd): d_client(client), d_fd(std::move(fd))
   {
     if (!s_connManager.registerConnection()) {
-      close(fd);
       throw std::runtime_error("Too many concurrent console connections");
     }
   }
-  ConsoleConnection(ConsoleConnection&& rhs): d_client(rhs.d_client), d_fd(rhs.d_fd)
+  ConsoleConnection(ConsoleConnection&& rhs): d_client(rhs.d_client), d_fd(std::move(rhs.d_fd))
   {
-    rhs.d_fd = -1;
   }
 
   ConsoleConnection(const ConsoleConnection&) = delete;
@@ -75,15 +73,14 @@ public:
 
   ~ConsoleConnection()
   {
-    if (d_fd != -1) {
-      close(d_fd);
+    if (d_fd.getHandle() != -1) {
       s_connManager.releaseConnection();
     }
   }
 
   int getFD() const
   {
-    return d_fd;
+    return d_fd.getHandle();
   }
 
   const ComboAddress& getClient() const
@@ -93,7 +90,7 @@ public:
 
 private:
   ComboAddress d_client;
-  int d_fd{-1};
+  FDWrapper d_fd;
 };
 
 void setConsoleMaximumConcurrentConnections(size_t max)
@@ -229,28 +226,30 @@ void doClient(ComboAddress server, const std::string& command)
     cout<<"Connecting to "<<server.toStringWithPort()<<endl;
   }
 
-  int fd=socket(server.sin4.sin_family, SOCK_STREAM, 0);
-  if (fd < 0) {
+  auto fd = FDWrapper(socket(server.sin4.sin_family, SOCK_STREAM, 0));
+  if (fd.getHandle() < 0) {
     cerr<<"Unable to connect to "<<server.toStringWithPort()<<endl;
     return;
   }
-  SConnect(fd, server);
-  setTCPNoDelay(fd);
+  SConnect(fd.getHandle(), server);
+  setTCPNoDelay(fd.getHandle());
   SodiumNonce theirs, ours, readingNonce, writingNonce;
   ours.init();
 
-  writen2(fd, (const char*)ours.value, sizeof(ours.value));
-  readn2(fd, (char*)theirs.value, sizeof(theirs.value));
+  writen2(fd.getHandle(), (const char*)ours.value, sizeof(ours.value));
+  readn2(fd.getHandle(), (char*)theirs.value, sizeof(theirs.value));
   readingNonce.merge(ours, theirs);
   writingNonce.merge(theirs, ours);
 
-    /* try sending an empty message, the server should send an empty
+  /* try sending an empty message, the server should send an empty
      one back. If it closes the connection instead, we are probably
      having a key mismatch issue. */
-  auto commandResult = sendMessageToServer(fd, "", readingNonce, writingNonce, false);
+  auto commandResult = sendMessageToServer(fd.getHandle(), "", readingNonce, writingNonce, false);
   if (commandResult == ConsoleCommandResult::ConnectionClosed) {
     cerr<<"The server closed the connection right away, likely indicating a key mismatch. Please check your setKey() directive."<<endl;
-    close(fd);
+    return;
+  }
+  else if (commandResult == ConsoleCommandResult::TooLarge) {
     return;
   }
   else if (commandResult == ConsoleCommandResult::TooLarge) {
@@ -259,9 +258,7 @@ void doClient(ComboAddress server, const std::string& command)
   }
 
   if (!command.empty()) {
-    sendMessageToServer(fd, command, readingNonce, writingNonce, false);
-
-    close(fd);
+    sendMessageToServer(fd.getHandle(), command, readingNonce, writingNonce, false);
     return; 
   }
 
@@ -276,32 +273,35 @@ void doClient(ComboAddress server, const std::string& command)
   }
   ofstream history(histfile, std::ios_base::app);
   string lastline;
-  for(;;) {
+  for (;;) {
     char* sline = readline("> ");
     rl_bind_key('\t',rl_complete);
-    if(!sline)
+    if (!sline) {
       break;
+    }
 
     string line(sline);
-    if(!line.empty() && line != lastline) {
+    if (!line.empty() && line != lastline) {
       add_history(sline);
       history << sline <<endl;
       history.flush();
     }
-    lastline=line;
+    lastline = line;
     free(sline);
     
-    if(line=="quit")
+    if (line == "quit") {
       break;
-    if(line=="help" || line=="?")
-      line="help()";
+    }
+    if (line == "help" || line == "?") {
+      line = "help()";
+    }
 
     /* no need to send an empty line to the server */
     if (line.empty()) {
       continue;
     }
 
-    commandResult = sendMessageToServer(fd, line, readingNonce, writingNonce, true);
+    commandResult = sendMessageToServer(fd.getHandle(), line, readingNonce, writingNonce, true);
     if (commandResult != ConsoleCommandResult::Valid) {
       break;
     }
@@ -309,7 +309,6 @@ void doClient(ComboAddress server, const std::string& command)
 #else
   errlog("Client mode requested but libedit support is not available");
 #endif /* HAVE_LIBEDIT */
-  close(fd);
 }
 
 #ifdef HAVE_LIBEDIT
@@ -794,13 +793,14 @@ static char* my_generator(const char* text, int state)
   string t(text);
   /* to keep it readable, we try to keep only 4 keywords per line
      and to start a new line when the first letter changes */
-  static int s_counter=0;
+  static int s_counter = 0;
   int counter=0;
-  if(!state)
-    s_counter=0;
+  if (!state) {
+    s_counter = 0;
+  }
 
-  for(const auto& keyword : g_consoleKeywords) {
-    if(boost::starts_with(keyword.name, t) && counter++ == s_counter)  {
+  for (const auto& keyword : g_consoleKeywords) {
+    if (boost::starts_with(keyword.name, t) && counter++ == s_counter)  {
       std::string value(keyword.name);
       s_counter++;
       if (keyword.function) {
@@ -818,8 +818,9 @@ static char* my_generator(const char* text, int state)
 char** my_completion( const char * text , int start,  int end)
 {
   char **matches=0;
-  if (start == 0)
+  if (start == 0) {
     matches = rl_completion_matches ((char*)text, &my_generator);
+  }
 
   // skip default filename completion.
   rl_attempted_completion_over = 1;
@@ -832,8 +833,7 @@ char** my_completion( const char * text , int start,  int end)
 
 static void controlClientThread(ConsoleConnection&& conn)
 {
-  try
-  {
+  try {
     setThreadName("dnsdist/conscli");
 
     setTCPNoDelay(conn.getFD());
@@ -845,7 +845,7 @@ static void controlClientThread(ConsoleConnection&& conn)
     readingNonce.merge(ours, theirs);
     writingNonce.merge(theirs, ours);
 
-    for(;;) {
+    for (;;) {
       uint32_t len;
       if (getMsgLen32(conn.getFD(), &len) != ConsoleCommandResult::Valid) {
         break;
@@ -866,7 +866,7 @@ static void controlClientThread(ConsoleConnection&& conn)
 
       string response;
       try {
-        bool withReturn=true;
+        bool withReturn = true;
       retry:;
         try {
           auto lua = g_lua.lock();
@@ -884,39 +884,42 @@ static void controlClientThread(ConsoleConnection&& conn)
               >
             >(withReturn ? ("return "+line) : line);
 
-          if(ret) {
+          if (ret) {
             if (const auto dsValue = boost::get<shared_ptr<DownstreamState>>(&*ret)) {
               if (*dsValue) {
-                response=(*dsValue)->getName()+"\n";
+                response = (*dsValue)->getName()+"\n";
               } else {
-                response="";
+                response = "";
               }
             }
             else if (const auto csValue = boost::get<ClientState*>(&*ret)) {
               if (*csValue) {
-                response=(*csValue)->local.toStringWithPort()+"\n";
+                response = (*csValue)->local.toStringWithPort()+"\n";
               } else {
-                response="";
+                response = "";
               }
             }
             else if (const auto strValue = boost::get<string>(&*ret)) {
-              response=*strValue+"\n";
+              response = *strValue+"\n";
             }
-            else if(const auto um = boost::get<std::unordered_map<string, double> >(&*ret)) {
+            else if (const auto um = boost::get<std::unordered_map<string, double> >(&*ret)) {
               using namespace json11;
               Json::object o;
-              for(const auto& v : *um)
-                o[v.first]=v.second;
+              for(const auto& v : *um) {
+                o[v.first] = v.second;
+              }
               Json out = o;
-              response=out.dump()+"\n";
+              response = out.dump()+"\n";
             }
           }
-          else
-            response=g_outputBuffer;
-          if(!getLuaNoSideEffect())
+          else {
+            response = g_outputBuffer;
+          }
+          if (!getLuaNoSideEffect()) {
             feedConfigDelta(line);
+          }
         }
-        catch(const LuaContext::SyntaxErrorException&) {
+        catch (const LuaContext::SyntaxErrorException&) {
           if(withReturn) {
             withReturn=false;
             goto retry;
@@ -928,23 +931,26 @@ static void controlClientThread(ConsoleConnection&& conn)
         response = "Command returned an object we can't print: " +std::string(e.what()) + "\n";
         // tried to return something we don't understand
       }
-      catch(const LuaContext::ExecutionErrorException& e) {
-        if(!strcmp(e.what(),"invalid key to 'next'"))
+      catch (const LuaContext::ExecutionErrorException& e) {
+        if (!strcmp(e.what(),"invalid key to 'next'")) {
           response = "Error: Parsing function parameters, did you forget parameter name?";
-        else
+        }
+        else {
           response = "Error: " + string(e.what());
+        }
+
         try {
           std::rethrow_if_nested(e);
-        } catch(const std::exception& ne) {
+        } catch (const std::exception& ne) {
           // ne is the exception that was thrown from inside the lambda
           response+= ": " + string(ne.what());
         }
-        catch(const PDNSException& ne) {
+        catch (const PDNSException& ne) {
           // ne is the exception that was thrown from inside the lambda
           response += ": " + string(ne.reason);
         }
       }
-      catch(const LuaContext::SyntaxErrorException& e) {
+      catch (const LuaContext::SyntaxErrorException& e) {
         response = "Error: " + string(e.what()) + ": ";
       }
       response = sodEncryptSym(response, g_consoleKey, writingNonce);
@@ -955,14 +961,14 @@ static void controlClientThread(ConsoleConnection&& conn)
       infolog("Closed control connection from %s", conn.getClient().toStringWithPort());
     }
   }
-  catch (const std::exception& e)
-  {
+  catch (const std::exception& e) {
     errlog("Got an exception in client connection from %s: %s", conn.getClient().toStringWithPort(), e.what());
   }
 }
 
 void controlThread(int fd, ComboAddress local)
 {
+  FDWrapper acceptFD(fd);
   try
   {
     setThreadName("dnsdist/control");
@@ -971,22 +977,21 @@ void controlThread(int fd, ComboAddress local)
     auto localACL = g_consoleACL.getLocal();
     infolog("Accepting control connections on %s", local.toStringWithPort());
 
-    while ((sock = SAccept(fd, client)) >= 0) {
+    while ((sock = SAccept(acceptFD.getHandle(), client)) >= 0) {
 
+      FDWrapper socket(sock);
       if (!sodIsValidKey(g_consoleKey)) {
         vinfolog("Control connection from %s dropped because we don't have a valid key configured, please configure one using setKey()", client.toStringWithPort());
-        close(sock);
         continue;
       }
 
       if (!localACL->match(client)) {
         vinfolog("Control connection from %s dropped because of ACL", client.toStringWithPort());
-        close(sock);
         continue;
       }
 
       try {
-        ConsoleConnection conn(client, sock);
+        ConsoleConnection conn(client, std::move(socket));
         if (g_logConsoleConnections) {
           warnlog("Got control connection from %s", client.toStringWithPort());
         }
@@ -999,9 +1004,7 @@ void controlThread(int fd, ComboAddress local)
       }
     }
   }
-  catch (const std::exception& e)
-  {
-    close(fd);
+  catch (const std::exception& e) {
     errlog("Control thread died: %s", e.what());
   }
 }