]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Wrap TCP connection objects in smart pointers 7108/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 26 Oct 2018 08:06:13 +0000 (10:06 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 26 Oct 2018 08:06:13 +0000 (10:06 +0200)
pdns/dnsdist-tcp.cc

index 9eba8619142117f6bc936ab58710d24b78ba6a21..eace3fd1b9759325ee7737696792f4cb6c419b52 100644 (file)
@@ -92,9 +92,34 @@ static int setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstre
 
 struct ConnectionInfo
 {
-  int fd;
+  ConnectionInfo(): cs(nullptr), fd(-1)
+  {
+  }
+
+  ConnectionInfo(const ConnectionInfo& rhs) = delete;
+  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
+
+  ConnectionInfo& operator=(ConnectionInfo&& rhs)
+  {
+    remote = rhs.remote;
+    cs = rhs.cs;
+    rhs.cs = nullptr;
+    fd = rhs.fd;
+    rhs.fd = -1;
+    return *this;
+  }
+
+  ~ConnectionInfo()
+  {
+    if (fd != -1) {
+      close(fd);
+      fd = -1;
+    }
+  }
+
   ComboAddress remote;
-  ClientState* cs;
+  ClientState* cs{nullptr};
+  int fd{-1};
 };
 
 uint64_t g_maxTCPQueuedConnections{1000};
@@ -264,7 +289,7 @@ void* tcpClientThread(int pipefd)
     }
 
     g_tcpclientthreads->decrementQueuedCount();
-    ci=*citmp;
+    ci=std::move(*citmp);
     delete citmp;
 
     uint16_t qlen, rlen;
@@ -648,10 +673,6 @@ void* tcpClientThread(int pipefd)
   drop:;
 
     vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
-    if (ci.fd >= 0) {
-      close(ci.fd);
-    }
-    ci.fd = -1;
 
     if (ds && outstanding) {
       outstanding = false;
@@ -683,13 +704,12 @@ void* tcpAcceptorThread(void* p)
   auto acl = g_ACL.getLocal();
   for(;;) {
     bool queuedCounterIncremented = false;
-    ConnectionInfo* ci = nullptr;
+    std::unique_ptr<ConnectionInfo> ci;
     tcpClientCountIncremented = false;
     try {
       socklen_t remlen = remote.getSocklen();
-      ci = new ConnectionInfo;
+      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo);
       ci->cs = cs;
-      ci->fd = -1;
 #ifdef HAVE_ACCEPT4
       ci->fd = accept4(cs->tcpFD, (struct sockaddr*)&remote, &remlen, SOCK_NONBLOCK);
 #else
@@ -701,26 +721,17 @@ void* tcpAcceptorThread(void* p)
 
       if(!acl->match(remote)) {
        g_stats.aclDrops++;
-       close(ci->fd);
-       delete ci;
-       ci=nullptr;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
       }
 
 #ifndef HAVE_ACCEPT4
       if (!setNonBlocking(ci->fd)) {
-        close(ci->fd);
-        delete ci;
-        ci=nullptr;
         continue;
       }
 #endif
       setTCPNoDelay(ci->fd);  // disable NAGLE
       if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
-        close(ci->fd);
-        delete ci;
-        ci=nullptr;
         vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
         continue;
       }
@@ -729,9 +740,6 @@ void* tcpAcceptorThread(void* p)
         std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
 
         if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
-          close(ci->fd);
-          delete ci;
-          ci=nullptr;
           vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
           continue;
         }
@@ -745,14 +753,19 @@ void* tcpAcceptorThread(void* p)
       int pipe = g_tcpclientthreads->getThread();
       if (pipe >= 0) {
         queuedCounterIncremented = true;
-        writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
+        auto tmp = ci.release();
+        try {
+          writen2WithTimeout(pipe, &tmp, sizeof(tmp), 0);
+        }
+        catch(...) {
+          delete tmp;
+          tmp = nullptr;
+          throw;
+        }
       }
       else {
         g_tcpclientthreads->decrementQueuedCount();
         queuedCounterIncremented = false;
-        close(ci->fd);
-        delete ci;
-        ci=nullptr;
         if(tcpClientCountIncremented) {
           decrementTCPClientCount(remote);
         }
@@ -760,13 +773,9 @@ void* tcpAcceptorThread(void* p)
     }
     catch(std::exception& e) {
       errlog("While reading a TCP question: %s", e.what());
-      if(ci && ci->fd >= 0) 
-       close(ci->fd);
       if(tcpClientCountIncremented) {
         decrementTCPClientCount(remote);
       }
-      delete ci;
-      ci = nullptr;
       if (queuedCounterIncremented) {
         g_tcpclientthreads->decrementQueuedCount();
       }