]> git.ipfire.org Git - thirdparty/open-vm-tools.git/commitdiff
Common source file change not directly applicable to open-vm-tools.
authorOliver Kurth <okurth@vmware.com>
Wed, 16 Jan 2019 22:53:04 +0000 (14:53 -0800)
committerOliver Kurth <okurth@vmware.com>
Wed, 16 Jan 2019 22:53:04 +0000 (14:53 -0800)
open-vm-tools/lib/poll/poll.c

index 374396e550ab9c803252c42c4dc3a29cad0bfffd..89245da31eca1e008820efc62d1dfd2bf334d2fb 100644 (file)
@@ -1,5 +1,5 @@
 /*********************************************************
- * Copyright (C) 1998-2017 VMware, Inc. All rights reserved.
+ * Copyright (C) 1998-2018 VMware, Inc. All rights reserved.
  *
  * This program is free software; you can redistribute it and/or modify it
  * under the terms of the GNU Lesser General Public License as published
@@ -414,15 +414,13 @@ Poll_CB_RTimeRemove(PollerFunction f,
 /*
  *-----------------------------------------------------------------------------
  *
- * PollSocketPairStartConnecting --
+ * PollSocketPairPrepare --
  *
- *      Helper function that does most of the work of creating
- *      a socket pair.
+ *      Do miscellaneous preparetion for the socket pair before connecting
  *
  * Results:
- *      Socket bound to a local address, and another connecting
- *      to that address.
- *      INVALID_SOCKET on error.  Use WSAGetLastError() for detail.
+ *      Socket bound to a local address, and another set properly.
+ *      TRUE if all preparetion succeed, otherwise FALSE.
  *
  * Side effects:
  *      None.
@@ -430,116 +428,326 @@ Poll_CB_RTimeRemove(PollerFunction f,
  *-----------------------------------------------------------------------------
  */
 
-static SOCKET
-PollSocketPairStartConnecting(Bool vmci,      // IN: vmci socket?
-                              Bool stream,    // IN: stream socket?
-                              Bool blocking,  // IN: blocking socket?
-                              SOCKET *s)      // OUT: connecting socket
+static Bool
+PollSocketPairPrepare(Bool blocking,           // IN: blocking socket?
+                      SOCKET *src,             // IN: client side socket
+                      SOCKET dst,              // IN: server side socket
+                      struct sockaddr *addr,   // IN: the address connected to
+                      int addrlen,             // IN: length of struct sockaddr
+                      int socketCommType)      // IN: SOCK_STREAM or SOCK_DGRAM?
 {
-   SOCKET temp = INVALID_SOCKET;
-   struct sockaddr *addr;
-   int addrlen;
-   struct sockaddr_vm vaddr;
-   struct sockaddr_in iaddr;
-   struct sockaddr_in6 iaddr6;
-   int savedError;
-   int socketCommType = stream ? SOCK_STREAM : SOCK_DGRAM;
-
-   if (vmci) {
-      addrlen = sizeof vaddr;
-      memset(&vaddr, 0, sizeof vaddr);
-      vaddr.svm_family = VMCISock_GetAFValue();
-      vaddr.svm_cid = VMADDR_CID_ANY;
-      vaddr.svm_port = VMADDR_PORT_ANY;
-
-      *s = socket(vaddr.svm_family, socketCommType, 0);
-      if (*s == INVALID_SOCKET) {
-         Log("%s: Could not create vmci socket.\n", __FUNCTION__);
-         goto out;
-      }
-      temp = socket(vaddr.svm_family, socketCommType, 0);
-      if (temp == INVALID_SOCKET) {
-         Log("%s: Could not create second vmci socket.\n", __FUNCTION__);
-         goto out;
-      }
-      addr = (struct sockaddr *)&vaddr;
-   } else {
-      // First try create a IPv6 socket
-      *s = socket(AF_INET6, socketCommType, 0);
-
-      if (*s != INVALID_SOCKET) {
-         // Set to IPv6 loopback address
-         memset(&iaddr6, 0, sizeof iaddr6);
-         iaddr6.sin6_family = AF_INET6;
-         iaddr6.sin6_addr = in6addr_loopback;
-         iaddr6.sin6_port = 0;
-         addr = (struct sockaddr *)&iaddr6;
-         addrlen = sizeof iaddr6;
-      } else {
-         // Try create the socket again, but using IPv4
-         *s = socket(AF_INET, socketCommType, 0);
-         if (*s == INVALID_SOCKET) {
-            Log("%s: Could not create inet socket.\n", __FUNCTION__);
-            goto out;
-         }
-
-         // Set to IPv4 loopback address
-         memset(&iaddr, 0, sizeof iaddr);
-         iaddr.sin_family = AF_INET;
-         iaddr.sin_addr = in4addr_loopback;
-         iaddr.sin_port = 0;
-         addr = (struct sockaddr *)&iaddr;
-         addrlen = sizeof iaddr;
-      }
-
-      temp = socket(addr->sa_family, socketCommType, 0);
-      if (temp == INVALID_SOCKET) {
-         Log("%s: Could not create second inet socket.\n", __FUNCTION__);
-         goto out;
-      }
-   }
-   if (bind(temp, addr, addrlen) == SOCKET_ERROR) {
+   if (bind(dst, addr, addrlen) == SOCKET_ERROR) {
       Log("%s: Could not bind socket.\n", __FUNCTION__);
-      goto outCloseTemp;
+      return FALSE;
    }
+
    if (!blocking) {
       unsigned long a = 1;
-      if (ioctlsocket(*s, FIONBIO, &a) == SOCKET_ERROR) {
+      if (ioctlsocket(*src, FIONBIO, &a) == SOCKET_ERROR) {
          Log("%s: Could not make socket non-blocking.\n", __FUNCTION__);
-         goto outCloseTemp;
+         return FALSE;
       }
    }
-   if (stream && listen(temp, 1) == SOCKET_ERROR) {
+
+   if (socketCommType == SOCK_STREAM && listen(dst, 1) == SOCKET_ERROR) {
       Log("%s: Could not listen on a socket.\n", __FUNCTION__);
-      goto outCloseTemp;
+      return FALSE;
    }
-   if (getsockname(temp, addr, &addrlen) == SOCKET_ERROR) {
+
+   if (getsockname(dst, addr, &addrlen) == SOCKET_ERROR) {
       Log("%s: getsockname() failed.\n", __FUNCTION__);
-      goto outCloseTemp;
-   }
-   if (vmci) {
-      vaddr.svm_cid = VMCISock_GetLocalCID();
+      return FALSE;
    }
+
+   return TRUE;
+}
+
+
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollSocketPairConnect --
+ *
+ *      Connects a socket to a given address.
+ *
+ * Results:
+ *      TRUE if connecting successfully, otherwise FALSE is returned.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static Bool
+PollSocketPairConnect(Bool blocking,           // IN: blocking socket?
+                      struct sockaddr *addr,   // IN: the address connected to
+                      int addrlen,             // IN: length of struct sockaddr
+                      SOCKET *s)               // IN: connecting socket
+{
    if (connect(*s, addr, addrlen) == SOCKET_ERROR) {
       if (blocking || WSAGetLastError() != WSAEWOULDBLOCK) {
          Log("%s: Could not connect to a local socket.\n", __FUNCTION__);
-         goto outCloseTemp;
+         return FALSE;
       }
    } else if (!blocking) {
       Log("%s: non-blocking socket connected immediately!\n", __FUNCTION__);
+      return FALSE;
+   }
+
+   return TRUE;
+}
+
+
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollSocketClose --
+ *
+ *      Close the socket, and restore the original last error.
+ *
+ * Results:
+ *      Socket is closed, original last error is restored.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static INLINE void
+PollSocketClose(SOCKET sock) {  // IN: the socket is being closed
+   int savedError = GetLastError();
+   closesocket(sock);
+   SetLastError(savedError);
+}
+
+
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollSocketPairConnecting --
+ *
+ *      Given necessary information, like socket family type, communication
+ *      type, socket address and socket type, this function initialize a socket
+ *      pair and make them connect to each other.
+ *
+ * Results:
+ *      Socket bound to a given address, and another connecting
+ *      to that address.
+ *      INVALID_SOCKET on error.  Use WSAGetLastError() for detail.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static SOCKET
+PollSocketPairConnecting(sa_family_t sa_family,    // IN: socket family type
+                         int socketCommType,       // IN: SOCK_STREAM or SOCK_DGRAM?
+                         struct sockaddr *addr,    // IN: the address connected to
+                         int addrlen,              // IN: length of struct sockaddr
+                         Bool blocking,            // IN: blocking socket?
+                         SOCKET *s)                // OUT: connecting socket
+{
+   SOCKET temp = INVALID_SOCKET;
+
+   *s = socket(sa_family, socketCommType, 0);
+   if (*s == INVALID_SOCKET) {
+      Log("%s: Could not create socket, socket family: %d.\n", __FUNCTION__,
+          sa_family);
+      goto out;
+   }
+
+   temp = socket(sa_family, socketCommType, 0);
+   if (temp == INVALID_SOCKET) {
+      PollSocketClose(*s);
+      *s = INVALID_SOCKET;
+      Log("%s: Could not create second socket, socket family: %d.\n",
+          __FUNCTION__, sa_family);
+      goto out;
+   }
+
+   if (!PollSocketPairPrepare(blocking, s, temp, addr, addrlen, socketCommType)) {
+      Log("%s: Could not prepare the socket pair for the following connecting,\
+          socket type: %d\n", __FUNCTION__, sa_family);
       goto outCloseTemp;
    }
+
+   if (!PollSocketPairConnect(blocking, addr, addrlen, s)) {
+      Log("%s: Could not make socket pair connected, socket type: %d",
+          __FUNCTION__, sa_family);
+      goto outCloseTemp;
+   }
+
    return temp;
 
 outCloseTemp:
-   savedError = GetLastError();
-   closesocket(temp);
-   SetLastError(savedError);
+   PollSocketClose(temp);
+
 out:
    return INVALID_SOCKET;
 }
 
 
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollIPv4SocketPairStartConnecting --
+ *
+ *      As one of the PollXXXSocketPairStartConnecting family, this function
+ *      creates an *IPv4* socket pair.
+ *
+ * Results:
+ *      Socket bound to a local address, and another connecting
+ *      to that address.
+ *      INVALID_SOCKET on error.  Use WSAGetLastError() for detail.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static SOCKET
+PollIPv4SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOCK_DGRAM?
+                                  Bool blocking,       // IN: blocking socket?
+                                  SOCKET *s)           // OUT: connecting socket
+{
+   struct sockaddr_in iaddr;
+   int addrlen;
+
+   addrlen = sizeof iaddr;
+   memset(&iaddr, 0, addrlen);
+   iaddr.sin_family = AF_INET;
+   iaddr.sin_addr = in4addr_loopback;
+   iaddr.sin_port = 0;
+
+   return PollSocketPairConnecting(iaddr.sin_family, socketCommType,
+                                   (struct sockaddr*) &iaddr, addrlen, blocking, s);
+}
+
+
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollIPv6SocketPairStartConnecting --
+ *
+ *      As one of the PollXXXSocketPairStartConnecting family, this function
+ *      creates an *IPv6* socket pair.
+ *
+ * Results:
+ *      Socket bound to a local address, and another connecting
+ *      to that address.
+ *      INVALID_SOCKET on error.  Use WSAGetLastError() for detail.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static SOCKET
+PollIPv6SocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOCK_DGRAM?
+                                  Bool blocking,       // IN: blocking socket?
+                                  SOCKET *s)           // OUT: connecting socket
+{
+   struct sockaddr_in6 iaddr6;
+   int addrlen;
+
+   addrlen = sizeof iaddr6;
+   memset(&iaddr6, 0, addrlen);
+   iaddr6.sin6_family = AF_INET6;
+   iaddr6.sin6_addr = in6addr_loopback;
+   iaddr6.sin6_port = 0;
+
+   return PollSocketPairConnecting(iaddr6.sin6_family, socketCommType,
+                                   (struct sockaddr*) &iaddr6, addrlen, blocking, s);
+}
+
+
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollVMCISocketPairStartConnecting --
+ *
+ *      As one of the PollXXXSocketPairStartConnecting family, this function
+ *      creates a *VMCI* socket pair.
+ *
+ * Results:
+ *      Socket bound to a local address, and another connecting
+ *      to that address.
+ *      INVALID_SOCKET on error.  Use WSAGetLastError() for detail.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static SOCKET
+PollVMCISocketPairStartConnecting(int socketCommType,  // IN: SOCK_STREAM or SOCK_DGRAM?
+                                  Bool blocking,       // IN: blocking socket?
+                                  SOCKET *s)           // OUT: connecting socket
+{
+   struct sockaddr_vm vaddr;
+   int addrlen;
+
+   addrlen = sizeof vaddr;
+   memset(&vaddr, 0, addrlen);
+   vaddr.svm_family = VMCISock_GetAFValue();
+   vaddr.svm_cid = VMADDR_CID_ANY;
+   vaddr.svm_port = VMADDR_PORT_ANY;
+   vaddr.svm_cid = VMCISock_GetLocalCID();
+
+   return PollSocketPairConnecting(vaddr.svm_family, socketCommType,
+                                   (struct sockaddr*) &vaddr, addrlen, blocking, s);
+}
+
+
+/*
+ *-----------------------------------------------------------------------------
+ *
+ * PollSocketPairStartConnecting --
+ *
+ *      Helper function that does most of the work of creating
+ *      a socket pair.
+ *
+ * Results:
+ *      Socket bound to a local address, and another connecting
+ *      to that address.
+ *      INVALID_SOCKET on error.  Use WSAGetLastError() for detail.
+ *
+ * Side effects:
+ *      None.
+ *
+ *-----------------------------------------------------------------------------
+ */
+
+static SOCKET
+PollSocketPairStartConnecting(Bool vmci,      // IN: vmci socket?
+                              Bool stream,    // IN: stream socket?
+                              Bool blocking,  // IN: blocking socket?
+                              SOCKET *s)      // OUT: connecting socket
+{
+   SOCKET temp = INVALID_SOCKET;
+   int socketCommType = stream ? SOCK_STREAM : SOCK_DGRAM;
+
+   if (vmci) {
+      temp = PollVMCISocketPairStartConnecting(socketCommType, blocking, s);
+   } else {
+      temp = PollIPv6SocketPairStartConnecting(socketCommType, blocking, s);
+
+      if (temp == INVALID_SOCKET) {
+         temp = PollIPv4SocketPairStartConnecting(socketCommType, blocking, s);
+      }
+   }
+
+   return temp;
+}
+
+
 /*
  *-----------------------------------------------------------------------------
  *
@@ -626,7 +834,7 @@ out:
    #define DROP_LOCK(_lock)
 #endif
 
-/* 
+/*
  * Make this queue length a little bit less than poll implementation's max
  * to allow for some sockets in the test program itself.
  */
@@ -1907,7 +2115,7 @@ PollUnitTest_StateMachine(void *clientData) // IN: Unused
    case 44:
 #if POLL_TESTVMCI
 
-      /* 
+      /*
        * The following tests only work inside the guest,
        * as stream VSockets are unsuported for host<->host communication.
        */