]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/iputils.cc
when we sendmsg, the socket may not be done connecting yet
[thirdparty/pdns.git] / pdns / iputils.cc
index d6facc02958f314afa2bf0bcc25e26fc317638a7..b7cad40319320f301d95b869d736ef5a1e383f68 100644 (file)
@@ -32,6 +32,10 @@ static void RuntimeError(const boost::format& fmt)
   throw runtime_error(fmt.str());
 }
 
+static void NetworkErr(const boost::format& fmt)
+{
+  throw NetworkError(fmt.str());
+}
 
 int SSocket(int family, int type, int flags)
 {
@@ -43,7 +47,7 @@ int SSocket(int family, int type, int flags)
 
 int SConnect(int sockfd, const ComboAddress& remote)
 {
-  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  int ret = connect(sockfd, reinterpret_cast<const struct sockaddr*>(&remote), remote.getSocklen());
   if(ret < 0) {
     int savederrno = errno;
     RuntimeError(boost::format("connecting socket to %s: %s") % remote.toStringWithPort() % strerror(savederrno));
@@ -53,10 +57,14 @@ int SConnect(int sockfd, const ComboAddress& remote)
 
 int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout)
 {
-  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  int ret = connect(sockfd, reinterpret_cast<const struct sockaddr*>(&remote), remote.getSocklen());
   if(ret < 0) {
     int savederrno = errno;
     if (savederrno == EINPROGRESS) {
+      if (timeout <= 0) {
+        return savederrno;
+      }
+
       /* we wait until the connection has been established */
       bool error = false;
       bool disconnected = false;
@@ -66,30 +74,30 @@ int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout)
           savederrno = 0;
           socklen_t errlen = sizeof(savederrno);
           if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, (void *)&savederrno, &errlen) == 0) {
-            RuntimeError(boost::format("connecting to %s failed: %s") % remote.toStringWithPort() % string(strerror(savederrno)));
+            NetworkErr(boost::format("connecting to %s failed: %s") % remote.toStringWithPort() % string(strerror(savederrno)));
           }
           else {
-            RuntimeError(boost::format("connecting to %s failed") % remote.toStringWithPort());
+            NetworkErr(boost::format("connecting to %s failed") % remote.toStringWithPort());
           }
         }
         if (disconnected) {
-          RuntimeError(boost::format("%s closed the connection") % remote.toStringWithPort());
+          NetworkErr(boost::format("%s closed the connection") % remote.toStringWithPort());
         }
         return 0;
       }
       else if (res == 0) {
-        RuntimeError(boost::format("timeout while connecting to %s") % remote.toStringWithPort());
+        NetworkErr(boost::format("timeout while connecting to %s") % remote.toStringWithPort());
       } else if (res < 0) {
         savederrno = errno;
-        RuntimeError(boost::format("waiting to connect to %s: %s") % remote.toStringWithPort() % string(strerror(savederrno)));
+        NetworkErr(boost::format("waiting to connect to %s: %s") % remote.toStringWithPort() % string(strerror(savederrno)));
       }
     }
     else {
-      RuntimeError(boost::format("connecting to %s: %s") % remote.toStringWithPort() % string(strerror(savederrno)));
+      NetworkErr(boost::format("connecting to %s: %s") % remote.toStringWithPort() % string(strerror(savederrno)));
     }
   }
 
-  return ret;
+  return 0;
 }
 
 int SBind(int sockfd, const ComboAddress& local)
@@ -145,8 +153,12 @@ bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv)
 }
 bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destination)
 {
-  memset(destination, 0, sizeof(*destination));
+  destination->reset();
+#ifdef __NetBSD__
+  struct cmsghdr* cmsg;
+#else
   const struct cmsghdr* cmsg;
+#endif
   for (cmsg = CMSG_FIRSTHDR(msgh); cmsg != NULL; cmsg = CMSG_NXTHDR(const_cast<struct msghdr*>(msgh), const_cast<struct cmsghdr*>(cmsg))) {
 #if defined(IP_PKTINFO)
      if ((cmsg->cmsg_level == IPPROTO_IP) && (cmsg->cmsg_type == IP_PKTINFO)) {
@@ -257,40 +269,112 @@ void ComboAddress::truncate(unsigned int bits) noexcept
   *place &= (~((1<<bitsleft)-1));
 }
 
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf)
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags)
 {
+  int remainingTime = totalTimeout;
+  time_t start = 0;
+  if (totalTimeout) {
+    start = time(nullptr);
+  }
+
   struct msghdr msgh;
   struct iovec iov;
   char cbuf[256];
+
+  /* Set up iov and msgh structures. */
+  memset(&msgh, 0, sizeof(struct msghdr));
+  msgh.msg_control = nullptr;
+  msgh.msg_controllen = 0;
+  if (dest) {
+    msgh.msg_name = reinterpret_cast<void*>(const_cast<ComboAddress*>(dest));
+    msgh.msg_namelen = dest->getSocklen();
+  }
+  else {
+    msgh.msg_name = nullptr;
+    msgh.msg_namelen = 0;
+  }
+
+  msgh.msg_flags = 0;
+
+  if (localItf != 0 && local) {
+    addCMsgSrcAddr(&msgh, cbuf, local, localItf);
+  }
+
+  if (localItf != 0 && local) {
+    addCMsgSrcAddr(&msgh, cbuf, local, localItf);
+  }
+
+  iov.iov_base = reinterpret_cast<void*>(const_cast<char*>(buffer));
+  iov.iov_len = len;
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  msgh.msg_flags = 0;
+
+  size_t sent = 0;
   bool firstTry = true;
-  fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), const_cast<char*>(buffer), len, &dest);
-  addCMsgSrcAddr(&msgh, cbuf, &local, localItf);
 
   do {
-    ssize_t written = sendmsg(fd, &msgh, 0);
 
-    if (written > 0)
-      return written;
+#ifdef MSG_FASTOPEN
+    if (flags & MSG_FASTOPEN && firstTry == false) {
+      flags &= ~MSG_FASTOPEN;
+    }
+#endif /* MSG_FASTOPEN */
+
+    ssize_t res = sendmsg(fd, &msgh, flags);
 
-    if (errno == EAGAIN) {
-      if (firstTry) {
-        int res = waitForRWData(fd, false, timeout, 0);
-        if (res > 0) {
-          /* there is room available */
-          firstTry = false;
+    if (res > 0) {
+      size_t written = static_cast<size_t>(res);
+      sent += written;
+
+      if (sent == len) {
+        return sent;
+      }
+
+      /* partial write */
+      iov.iov_len -= written;
+      iov.iov_base = reinterpret_cast<void*>(reinterpret_cast<char*>(iov.iov_base) + written);
+      written = 0;
+    }
+    else if (res == -1) {
+      if (errno == EINTR) {
+        continue;
+      }
+      else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS || errno == ENOTCONN) {
+        /* EINPROGRESS might happen with non blocking socket,
+           especially with TCP Fast Open */
+        if (totalTimeout <= 0 && idleTimeout <= 0) {
+          return sent;
         }
-        else if (res == 0) {
+
+        if (firstTry) {
+          int res = waitForRWData(fd, false, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime, 0);
+          if (res > 0) {
+            /* there is room available */
+            firstTry = false;
+          }
+          else if (res == 0) {
+            throw runtime_error("Timeout while waiting to write data");
+          } else {
+            throw runtime_error("Error while waiting for room to write data");
+          }
+        }
+        else {
           throw runtime_error("Timeout while waiting to write data");
-        } else {
-          throw runtime_error("Error while waiting for room to write data");
         }
       }
       else {
-        throw runtime_error("Timeout while waiting to write data");
+        unixDie("failed in sendMsgWithTimeout");
       }
     }
-    else {
-      unixDie("failed in write2WithTimeout");
+    if (totalTimeout) {
+      time_t now = time(nullptr);
+      int elapsed = now - start;
+      if (elapsed >= remainingTime) {
+        throw runtime_error("Timeout while sending data");
+      }
+      start = now;
+      remainingTime -= elapsed;
     }
   }
   while (firstTry);