]> git.ipfire.org Git - thirdparty/open-vm-tools.git/commitdiff
Changes to common source files not directly applicable to open-vm-tools.
authorOliver Kurth <okurth@vmware.com>
Tue, 26 May 2020 22:32:55 +0000 (15:32 -0700)
committerOliver Kurth <okurth@vmware.com>
Tue, 26 May 2020 22:32:55 +0000 (15:32 -0700)
open-vm-tools/lib/include/poll.h
open-vm-tools/lib/poll/poll.c

index ef9281b49dde4e97cc5cd5fef5dfe276dd1c90c6..c90f5dcd127b344bbbd9772e02a93f0761c0d110 100644 (file)
@@ -1,5 +1,5 @@
 /*********************************************************
- * Copyright (C) 1998-2019 VMware, Inc. All rights reserved.
+ * Copyright (C) 1998-2020 VMware, Inc. All rights reserved.
  *
  * This program is free software; you can redistribute it and/or modify it
  * under the terms of the GNU Lesser General Public License as published
@@ -232,6 +232,12 @@ typedef Bool (*PollerErrorFn)(const char *errorStr);
  *      implementations are distinct from the core poll code.
  */
 
+
+/* Socket pair created with non-blocking mode */
+#define POLL_OPTIONS_SOCKET_PAIR_NONBLOCK_CONN  0x01
+
+typedef unsigned int SocketSpecialOpts;
+
 typedef struct PollOptions {
    Bool locked;           // Use internal MXUser for locking
    Bool allowFullQueue;   // Don't assert when device event queue is full.
@@ -239,6 +245,7 @@ typedef struct PollOptions {
    PollerFireWrapper fireWrapperFn;  // optional; may be useful for stats
    void *fireWrapperData; // optional
    PollerErrorFn errorFn; // optional; called upon unrecoverable error
+   SocketSpecialOpts pollSocketOpts;
 } PollOptions;
 
 
@@ -251,7 +258,7 @@ void Poll_InitCF(void);  // On top of CoreFoundation for OSX
 /*
  * Functions
  */
-int Poll_SocketPair(Bool vmci, Bool stream, int fds[2]);
+int Poll_SocketPair(Bool vmci, Bool stream, int fds[2], SocketSpecialOpts opts);
 void Poll_Loop(Bool loop, Bool *exit, PollClass c);
 void Poll_LoopTimeout(Bool loop, Bool *exit, PollClass c, int timeout);
 Bool Poll_LockingEnabled(void);
index 4ddd8637ecc5caa21a94598a47fae91778f6c28d..f53832aebade53851a266d5fb69b79a5b28d2086 100644 (file)
@@ -42,6 +42,7 @@
    #include <winsock2.h>
    #include <ws2tcpip.h>
    #include "err.h"
+   #include "preference.h"
 #endif
 
 /*
@@ -437,25 +438,29 @@ PollSocketPairPrepare(Bool blocking,           // IN: blocking socket?
                       int socketCommType)      // IN: SOCK_STREAM or SOCK_DGRAM?
 {
    if (bind(dst, addr, addrlen) == SOCKET_ERROR) {
-      Log("%s: Could not bind socket.\n", __FUNCTION__);
+      Log("%s: Could not bind socket %d, error %d.\n",
+          __FUNCTION__, dst, WSAGetLastError());
       return FALSE;
    }
 
    if (!blocking) {
       unsigned long a = 1;
       if (ioctlsocket(*src, FIONBIO, &a) == SOCKET_ERROR) {
-         Log("%s: Could not make socket non-blocking.\n", __FUNCTION__);
+         Log("%s: Could not make socket %d non-blocking, error %d.\n",
+             __FUNCTION__, *src, WSAGetLastError());
          return FALSE;
       }
    }
 
    if (socketCommType == SOCK_STREAM && listen(dst, 1) == SOCKET_ERROR) {
-      Log("%s: Could not listen on a socket.\n", __FUNCTION__);
+      Log("%s: Could not listen on a socket %d, error %d.\n",
+          __FUNCTION__, dst, WSAGetLastError());
       return FALSE;
    }
 
    if (getsockname(dst, addr, &addrlen) == SOCKET_ERROR) {
-      Log("%s: getsockname() failed.\n", __FUNCTION__);
+      Log("%s: getsockname() failed for socket %d, error %d.\n",
+         __FUNCTION__, dst, WSAGetLastError());
       return FALSE;
    }
 
@@ -483,16 +488,76 @@ static Bool
 PollSocketPairConnect(Bool blocking,           // IN: blocking socket?
                       struct sockaddr *addr,   // IN: the address connected to
                       int addrlen,             // IN: length of struct sockaddr
-                      SOCKET *s)               // IN: connecting socket
+                      SOCKET *s,               // IN: connecting socket
+                      SocketSpecialOpts opts)   // IN: socket special options
 {
-   if (connect(*s, addr, addrlen) == SOCKET_ERROR) {
+   if (blocking && (opts & POLL_OPTIONS_SOCKET_PAIR_NONBLOCK_CONN)) {
+      /* Change blocking socket to non-blocking socket for timeout */
+      unsigned long unblock = 1;
+      if (ioctlsocket(*s, FIONBIO, &unblock) == SOCKET_ERROR) {
+         Log("%s: Set socket %d to non-blocking mode failed, error %d.\n",
+             __FUNCTION__, *s, WSAGetLastError());
+         return FALSE;
+      }
+
+      if (connect(*s, addr, addrlen) == SOCKET_ERROR) {
+         WSAPOLLFD pollFds[1];
+         /* wait timeout seconds */
+         unsigned int timeout = Preference_GetLong(3,
+            "pref.wsa.socket.pair.connect.timeout.seconds") * 1000;
+         int ret = WSAGetLastError();
+         if (ret != WSAEWOULDBLOCK) {
+            /* connection failed */
+            Log("%s: Non-blocking socket %d could not connect to a local "
+                "socket, error %d.\n", __FUNCTION__, *s, WSAGetLastError());
+            return FALSE;
+         }
+
+         pollFds[0].fd = *s;
+         pollFds[0].events = POLLWRNORM;
+         pollFds[0].revents = 0;
+
+         ret = WSAPoll(pollFds, 1, timeout);
+         if (ret <= 0) {
+            /* WSAPoll failed or connection timed out */
+            if (ret == 0) {
+                WSASetLastError(WSAETIMEDOUT);
+            }
+            Log("%s: Non-blocking socket %d connects to a local socket "
+                "failed, error %d.\n", __FUNCTION__, *s, WSAGetLastError());
+            return FALSE;
+         }
+
+         if ((pollFds[0].revents &
+            (POLLWRNORM|POLLERR|POLLHUP)) != POLLWRNORM) {
+            /* connection failed */
+            int error = 0;
+            int errLen = sizeof error;
+            getsockopt(*s, SOL_SOCKET, SO_ERROR, (char *)&error, &errLen);
+            WSASetLastError(error);
+            Log("%s: Non-blocking socket %d connect to a local socket failed, "
+                "error %d.\n", __FUNCTION__, *s, WSAGetLastError());
+            return FALSE;
+         }
+      }
+      /* connection successful */
+      Log("%s: Non-blocking socket %d connected successfully with "
+          "socket type %d.\n", __FUNCTION__, *s, addr->sa_family);
+      unblock = 0;
+      if (ioctlsocket(*s, FIONBIO, &unblock) == SOCKET_ERROR) {
+         Log("%s: Non-blocking socket %d restored to blocking mode failed, "
+             "error %d.\n", __FUNCTION__, *s, WSAGetLastError());
+         return FALSE;
+      }
+   } else if (connect(*s, addr, addrlen) == SOCKET_ERROR) {
       if (blocking || WSAGetLastError() != WSAEWOULDBLOCK) {
-         Log("%s: Could not connect to a local socket.\n", __FUNCTION__);
+         Log("%s: socket %d could not connect to a local socket, "
+             "error %d.\n", __FUNCTION__, *s, WSAGetLastError());
          return FALSE;
       }
-   } else if (!blocking) {
-      Log("%s: non-blocking socket connected immediately!\n", __FUNCTION__);
-      return FALSE;
+   } else {
+      Log("%s: non-blocking socket %d connected immediately!\n",
+          __FUNCTION__, *s);
    }
 
    return TRUE;
@@ -549,7 +614,8 @@ PollSocketPairConnecting(sa_family_t sa_family,    // IN: socket family type
                          struct sockaddr *addr,    // IN: the address connected to
                          int addrlen,              // IN: length of struct sockaddr
                          Bool blocking,            // IN: blocking socket?
-                         SOCKET *s)                // OUT: connecting socket
+                         SOCKET *s,                // OUT: connecting socket
+                         SocketSpecialOpts opts)   // IN: socket special options
 {
    SOCKET temp = INVALID_SOCKET;
 
@@ -570,14 +636,15 @@ PollSocketPairConnecting(sa_family_t sa_family,    // IN: socket family type
    }
 
    if (!PollSocketPairPrepare(blocking, s, temp, addr, addrlen, socketCommType)) {
-      Log("%s: Could not prepare the socket pair for the following connecting,\
-          socket type: %d\n", __FUNCTION__, sa_family);
+      Log("%s: Could not prepare the socket pair for the following "
+          "connecting, socket type: %d, sockets: %d, %d.\n",
+          __FUNCTION__, sa_family, *s, temp);
       goto outCloseTemp;
    }
 
-   if (!PollSocketPairConnect(blocking, addr, addrlen, s)) {
-      Log("%s: Could not make socket pair connected, socket type: %d",
-          __FUNCTION__, sa_family);
+   if (!PollSocketPairConnect(blocking, addr, addrlen, s, opts)) {
+      Log("%s: Could not make socket pair connected, socket type: %d, "
+          "sockets: %d, %d.\n", __FUNCTION__, sa_family, *s, temp);
       goto outCloseTemp;
    }
 
@@ -613,7 +680,8 @@ out:
 static SOCKET
 PollIPv4SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOCK_DGRAM?
                                   Bool blocking,       // IN: blocking socket?
-                                  SOCKET *s)           // OUT: connecting socket
+                                  SOCKET *s,           // OUT: connecting socket
+                                  SocketSpecialOpts opts) // IN: socket special options
 {
    struct sockaddr_in iaddr;
    int addrlen;
@@ -625,7 +693,8 @@ PollIPv4SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOC
    iaddr.sin_port = 0;
 
    return PollSocketPairConnecting(iaddr.sin_family, socketCommType,
-                                   (struct sockaddr*) &iaddr, addrlen, blocking, s);
+                                   (struct sockaddr*) &iaddr, addrlen,
+                                   blocking, s, opts);
 }
 
 
@@ -651,7 +720,8 @@ PollIPv4SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOC
 static SOCKET
 PollIPv6SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOCK_DGRAM?
                                   Bool blocking,       // IN: blocking socket?
-                                  SOCKET *s)           // OUT: connecting socket
+                                  SOCKET *s,           // OUT: connecting socket
+                                  SocketSpecialOpts opts)  // IN: socket special options
 {
    struct sockaddr_in6 iaddr6;
    int addrlen;
@@ -663,7 +733,8 @@ PollIPv6SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOC
    iaddr6.sin6_port = 0;
 
    return PollSocketPairConnecting(iaddr6.sin6_family, socketCommType,
-                                   (struct sockaddr*) &iaddr6, addrlen, blocking, s);
+                                   (struct sockaddr*) &iaddr6, addrlen,
+                                   blocking, s, opts);
 }
 
 
@@ -702,7 +773,8 @@ PollVMCISocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOC
    vaddr.svm_cid = VMCISock_GetLocalCID();
 
    return PollSocketPairConnecting(vaddr.svm_family, socketCommType,
-                                   (struct sockaddr*) &vaddr, addrlen, blocking, s);
+                                   (struct sockaddr*) &vaddr, addrlen,
+                                   blocking, s, 0);
 }
 
 
@@ -729,7 +801,8 @@ static SOCKET
 PollSocketPairStartConnecting(Bool vmci,      // IN: vmci socket?
                               Bool stream,    // IN: stream socket?
                               Bool blocking,  // IN: blocking socket?
-                              SOCKET *s)      // OUT: connecting socket
+                              SOCKET *s,     // OUT: connecting socket
+                              SocketSpecialOpts opts)  // IN: socket special options
 {
    SOCKET temp = INVALID_SOCKET;
    int socketCommType = stream ? SOCK_STREAM : SOCK_DGRAM;
@@ -737,13 +810,14 @@ PollSocketPairStartConnecting(Bool vmci,      // IN: vmci socket?
    if (vmci) {
       temp = PollVMCISocketPairStartConnecting(socketCommType, blocking, s);
    } else {
-      temp = PollIPv6SocketPairStartConnecting(socketCommType, blocking, s);
+      temp = PollIPv6SocketPairStartConnecting(socketCommType, blocking,
+                                               s, opts);
 
       if (temp == INVALID_SOCKET) {
-         temp = PollIPv4SocketPairStartConnecting(socketCommType, blocking, s);
+         temp = PollIPv4SocketPairStartConnecting(socketCommType, blocking,
+                                                  s, opts);
       }
    }
-
    return temp;
 }
 
@@ -767,14 +841,16 @@ PollSocketPairStartConnecting(Bool vmci,      // IN: vmci socket?
 int
 Poll_SocketPair(Bool vmci,     // IN: create vmci pair?
                 Bool stream,   // IN: stream socket?
-                int fds[2])    // OUT: 2 sockets connected to each other
+                int fds[2],    // OUT: 2 sockets connected to each other
+                SocketSpecialOpts opts)  // IN: socket special options
 {
    SOCKET temp = INVALID_SOCKET;
 
    fds[0] = INVALID_SOCKET;
    fds[1] = INVALID_SOCKET;
 
-   temp = PollSocketPairStartConnecting(vmci, stream, TRUE, (SOCKET *)&fds[0]);
+   temp = PollSocketPairStartConnecting(vmci, stream, TRUE,
+                                        (SOCKET *)&fds[0], opts);
    if (temp == INVALID_SOCKET) {
       goto out;
    }
@@ -2051,7 +2127,7 @@ PollUnitTest_StateMachine(void *clientData) // IN: Unused
 #ifdef _WIN32
          socketPairs[i].fds[0] = INVALID_SOCKET;
          socketPairs[i].fds[1] = INVALID_SOCKET;
-         if (Poll_SocketPair(useVMCI, TRUE, socketPairs[i].fds) < 0) {
+         if (Poll_SocketPair(useVMCI, TRUE, socketPairs[i].fds, 0) < 0) {
             Warning("%s:   failure -- error creating socketpair, iteration %d\n",
                     __FUNCTION__, i);
             break;
@@ -2128,7 +2204,7 @@ PollUnitTest_StateMachine(void *clientData) // IN: Unused
          closesocket(fds[1]);
          fds[0] = INVALID_SOCKET;
          fds[1] = INVALID_SOCKET;
-         if (Poll_SocketPair(TRUE, TRUE, fds) < 0) {
+         if (Poll_SocketPair(TRUE, TRUE, fds, 0) < 0) {
             Warning("%s:   failure -- error creating vmci socketpair\n",
                     __FUNCTION__);
             state ++;
@@ -2171,7 +2247,7 @@ PollUnitTest_StateMachine(void *clientData) // IN: Unused
          closesocket(fds[1]);
          fds[0] = INVALID_SOCKET;
          fds[1] = INVALID_SOCKET;
-         if (Poll_SocketPair(FALSE, TRUE, fds) < 0) {
+         if (Poll_SocketPair(FALSE, TRUE, fds, 0) < 0) {
             Warning("%s:   failure -- error creating socketpair\n",
                     __FUNCTION__);
             state += 3;
@@ -2201,7 +2277,7 @@ PollUnitTest_StateMachine(void *clientData) // IN: Unused
       closesocket(fds[1]);
       fds[0] = INVALID_SOCKET;
       fds[1] = INVALID_SOCKET;
-      if (Poll_SocketPair(FALSE, TRUE, fds) < 0)
+      if (Poll_SocketPair(FALSE, TRUE, fds, 0) < 0)
    #else
       close(fds[0]);
       close(fds[1]);
@@ -2406,7 +2482,7 @@ PollUnitTest(Bool vmx)  // IN: use vmx-size poll queue
    }
    fds[0] = INVALID_SOCKET;
    fds[1] = INVALID_SOCKET;
-   if (Poll_SocketPair(FALSE, TRUE, fds) < 0) {
+   if (Poll_SocketPair(FALSE, TRUE, fds, 0) < 0) {
 #else
    fds[0] = -1;
    fds[1] = -1;