]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix a small race in the NetworkListener
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Dec 2023 14:08:37 +0000 (15:08 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 8 Dec 2023 07:44:56 +0000 (08:44 +0100)
The main thread needs to be able to access the data even if the
NetworkListener object has been destroyed first, which usually only
happens when DNSdist is exiting, but could also happen earlier if
the Lua handle is garbage collected.

pdns/dnsdistdist/dnsdist-lua-network.cc
pdns/dnsdistdist/dnsdist-lua-network.hh

index ec477255ad417b5efc840bb0a39e3e51e08abac4..5a2e3c2f0019b5eb732308d5186be06e9e2bd726 100644 (file)
 
 namespace dnsdist
 {
-NetworkListener::NetworkListener() :
+NetworkListener::ListenerData::ListenerData() :
   d_mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(10)))
 {
 }
 
+NetworkListener::NetworkListener() :
+  d_data(std::make_shared<ListenerData>())
+{
+}
+
+NetworkListener::~NetworkListener()
+{
+  d_data->d_exiting = true;
+}
+
 void NetworkListener::readCB(int desc, FDMultiplexer::funcparam_t& param)
 {
   auto cbData = boost::any_cast<std::shared_ptr<NetworkListener::CBData>>(param);
@@ -74,7 +84,7 @@ void NetworkListener::readCB(int desc, FDMultiplexer::funcparam_t& param)
 
 bool NetworkListener::addUnixListeningEndpoint(const std::string& path, NetworkListener::EndpointID id, NetworkListener::NetworkDatagramCB cb)
 {
-  if (d_running == true) {
+  if (d_data->d_running == true) {
     throw std::runtime_error("NetworkListener should not be altered at runtime");
   }
 
@@ -114,36 +124,48 @@ bool NetworkListener::addUnixListeningEndpoint(const std::string& path, NetworkL
   auto cbData = std::make_shared<CBData>();
   cbData->d_endpoint = id;
   cbData->d_cb = std::move(cb);
-  d_mplexer->addReadFD(sock.getHandle(), readCB, cbData);
+  d_data->d_mplexer->addReadFD(sock.getHandle(), readCB, cbData);
 
-  d_sockets.insert({path, std::move(sock)});
+  d_data->d_sockets.insert({path, std::move(sock)});
   return true;
 }
 
-void NetworkListener::runOnce(struct timeval& now, uint32_t timeout)
+void NetworkListener::runOnce(ListenerData& data, timeval& now, uint32_t timeout)
 {
-  d_running = true;
-  if (d_sockets.empty()) {
+  if (data.d_exiting) {
+    return;
+  }
+
+  data.d_running = true;
+  if (data.d_sockets.empty()) {
     throw runtime_error("NetworkListener started with no sockets");
   }
 
-  d_mplexer->run(&now, timeout);
+  data.d_mplexer->run(&now, timeout);
+}
+
+void NetworkListener::runOnce(timeval& now, uint32_t timeout)
+{
+  runOnce(*d_data, now, timeout);
 }
 
-void NetworkListener::mainThread()
+void NetworkListener::mainThread(std::shared_ptr<ListenerData>& dataArg)
 {
+  /* take our own copy of the shared_ptr so it's still alive if the NetworkListener object
+     gets destroyed while we are still running */
+  auto data = dataArg;
   setThreadName("dnsdist/lua-net");
-  struct timeval now;
+  timeval now{};
 
-  while (true) {
-    runOnce(now, -1);
+  while (!data->d_exiting) {
+    runOnce(*data, now, -1);
   }
 }
 
 void NetworkListener::start()
 {
   std::thread main = std::thread([this] {
-    mainThread();
+    mainThread(d_data);
   });
   main.detach();
 }
index a63efd479777b1528b7eec856da40d4fbb882357..3d12f63acdf27bfaa232394b76e3cbf065bedc4e 100644 (file)
@@ -34,16 +34,32 @@ class NetworkListener
 {
 public:
   NetworkListener();
+  NetworkListener(const NetworkListener&) = delete;
+  NetworkListener(NetworkListener&&) = delete;
+  NetworkListener& operator=(const NetworkListener&) = delete;
+  NetworkListener& operator=(NetworkListener&&) = delete;
+  ~NetworkListener();
 
   using EndpointID = uint16_t;
   using NetworkDatagramCB = std::function<void(EndpointID endpoint, std::string&& dgram, const std::string& from)>;
   bool addUnixListeningEndpoint(const std::string& path, EndpointID id, NetworkDatagramCB cb);
   void start();
-  void runOnce(struct timeval& now, uint32_t timeout);
+  void runOnce(timeval& now, uint32_t timeout);
 
 private:
+  struct ListenerData
+  {
+    ListenerData();
+
+    std::unique_ptr<FDMultiplexer> d_mplexer;
+    std::unordered_map<std::string, Socket> d_sockets;
+    std::atomic<bool> d_running{false};
+    std::atomic<bool> d_exiting{false};
+  };
+
   static void readCB(int desc, FDMultiplexer::funcparam_t& param);
-  void mainThread();
+  static void mainThread(std::shared_ptr<ListenerData>& data);
+  static void runOnce(ListenerData& data, timeval& now, uint32_t timeout);
 
   struct CBData
   {
@@ -51,9 +67,7 @@ private:
     EndpointID d_endpoint;
   };
 
-  std::unique_ptr<FDMultiplexer> d_mplexer;
-  std::unordered_map<std::string, Socket> d_sockets;
-  std::atomic<bool> d_running{false};
+  std::shared_ptr<ListenerData> d_data;
 };
 
 class NetworkEndpoint