]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: Revise interface creation wait
authorSimon Rozman <simon@rozman.si>
Thu, 7 Mar 2019 14:19:27 +0000 (15:19 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 7 Mar 2019 20:12:20 +0000 (21:12 +0100)
DIF_INSTALLDEVICE returns almost immediately, while the device
installation continues in the background. It might take a while, before
all registry keys and values are populated.

Previously, wireguard-go waited for HKLM\SYSTEM\CurrentControlSet\
Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\<id> registry key
only.

Followed by a SetInterfaceName() method of Wintun struct which tried to
access HKLM\SYSTEM\CurrentControlSet\Control\Network\
{4D36E972-E325-11CE-BFC1-08002BE10318}\<id>\Connection registry key
might not be available yet.

This commit loops until both registry keys are available before
returning from CreateInterface() function.

Signed-off-by: Simon Rozman <simon@rozman.si>
tun/wintun/setupapi/setupapi_windows.go
tun/wintun/setupapi/setupapi_windows_test.go
tun/wintun/wintun_windows.go

index 71732a40e2ac88f5563e7c544c7f1bf523c2fcab..5f9e05c0564744cc2516f18ec85a17b6eda03388 100644 (file)
@@ -7,12 +7,14 @@ package setupapi
 
 import (
        "encoding/binary"
+       "errors"
        "fmt"
        "syscall"
        "unsafe"
 
        "golang.org/x/sys/windows"
        "golang.org/x/sys/windows/registry"
+       "golang.zx2c4.com/wireguard/tun/wintun/guid"
 )
 
 //sys  setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW
@@ -234,6 +236,33 @@ func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DI
        return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired)
 }
 
+// GetInterfaceID method returns network interface ID.
+func (deviceInfoSet DevInfo) GetInterfaceID(deviceInfoData *DevInfoData) (*windows.GUID, error) {
+       // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
+       key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, DICS_FLAG_GLOBAL, 0, DIREG_DRV, registry.READ)
+       if err != nil {
+               return nil, errors.New("Device-specific registry key open failed: " + err.Error())
+       }
+       defer key.Close()
+
+       // Read the NetCfgInstanceId value.
+       value, valueType, err := key.GetStringValue("NetCfgInstanceId")
+       if err != nil {
+               return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
+       }
+       if valueType != registry.SZ {
+               return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
+       }
+
+       // Convert to windows.GUID.
+       ifid, err := guid.FromString(value)
+       if err != nil {
+               return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
+       }
+
+       return ifid, nil
+}
+
 //sys  setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW
 
 // SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
index 30f36920d3bf6e0ae894e577ab6d1e1a48231c7d..c6f4a1552ada171373467aa7c059049aedc5af13 100644 (file)
@@ -291,6 +291,11 @@ func TestSetupDiOpenDevRegKey(t *testing.T) {
                        t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error())
                }
                defer key.Close()
+
+               _, err = devInfoList.GetInterfaceID(data)
+               if err != nil {
+                       t.Errorf("Error calling GetInterfaceID: %s", err.Error())
+               }
        }
 }
 
index 85d29f4776457057b5637b971581adfe86fd1179..69fd30cd912f0c9c73195e23a054d736c86c4492 100644 (file)
@@ -58,27 +58,24 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
 
        // Iterate.
        for index := 0; ; index++ {
-               // Get the device from the list.
+               // Get the device from the list. Should anything be wrong with this device, continue with next.
                deviceData, err := devInfoList.EnumDeviceInfo(index)
                if err != nil {
                        if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ {
                                break
                        }
-                       // Something is wrong with this device. Skip it.
                        continue
                }
 
                // Get interface ID.
-               ifid, err := getInterfaceID(devInfoList, deviceData, 1)
+               ifid, err := devInfoList.GetInterfaceID(deviceData)
                if err != nil {
-                       // Something is wrong with this device. Skip it.
                        continue
                }
 
                // Get interface name.
                ifname2, err := ((*Wintun)(ifid)).GetInterfaceName()
                if err != nil {
-                       // Something is wrong with this device. Skip it.
                        continue
                }
 
@@ -243,8 +240,74 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
                        rebootRequired = true
                }
 
-               // Get network interface ID from registry. Retry for max 30sec.
-               ifid, err = getInterfaceID(devInfoList, deviceData, 30)
+               // Get network interface ID from registry. DIF_INSTALLDEVICE returns almost immediately,
+               // while the device installation continues in the background. It might take a while, before
+               // all registry keys and values are populated.
+               getInterfaceID := func() (*windows.GUID, error) {
+                       // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
+                       keyDev, err := devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ)
+                       if err != nil {
+                               return nil, errors.New("Device-specific registry key open failed: " + err.Error())
+                       }
+                       defer keyDev.Close()
+
+                       // Read the NetCfgInstanceId value.
+                       value, err := getRegStringValue(keyDev, "NetCfgInstanceId")
+                       if err != nil {
+                               if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+                                       return nil, err
+                               }
+
+                               return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
+                       }
+
+                       // Convert to windows.GUID.
+                       ifid, err := guid.FromString(value)
+                       if err != nil {
+                               return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
+                       }
+
+                       keyNetName := fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), value)
+                       keyNet, err := registry.OpenKey(registry.LOCAL_MACHINE, keyNetName, registry.QUERY_VALUE)
+                       if err != nil {
+                               if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+                                       return nil, err
+                               }
+
+                               return nil, errors.New(fmt.Sprintf("RegOpenKeyEx(\"%v\") failed: ", keyNetName) + err.Error())
+                       }
+                       defer keyNet.Close()
+
+                       // Query the interface name.
+                       _, valueType, err := keyNet.GetValue("Name", nil)
+                       if err != nil {
+                               if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+                                       return nil, err
+                               }
+
+                               return nil, errors.New("RegQueryValueEx(\"Name\") failed: " + err.Error())
+                       }
+                       switch valueType {
+                       case registry.SZ, registry.EXPAND_SZ:
+                       default:
+                               return nil, fmt.Errorf("Interface name registry value is not REG_SZ or REG_EXPAND_SZ (expected: %v or %v, provided: %v)", registry.SZ, registry.EXPAND_SZ, valueType)
+                       }
+
+                       // TUN interface is ready. (As far as we need it.)
+                       return ifid, nil
+               }
+               for numAttempts := 0; numAttempts < 30; numAttempts++ {
+                       ifid, err = getInterfaceID()
+                       if err != nil {
+                               if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+                                       // Wait and retry. TODO: Wait for a cancellable event instead.
+                                       time.Sleep(1000 * time.Millisecond)
+                                       continue
+                               }
+                       }
+
+                       break
+               }
        }
 
        if err == nil {
@@ -294,20 +357,18 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
 
        // Iterate.
        for index := 0; ; index++ {
-               // Get the device from the list.
+               // Get the device from the list. Should anything be wrong with this device, continue with next.
                deviceData, err := devInfoList.EnumDeviceInfo(index)
                if err != nil {
                        if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ {
                                break
                        }
-                       // Something is wrong with this device. Skip it.
                        continue
                }
 
                // Get interface ID.
-               ifid2, err := getInterfaceID(devInfoList, deviceData, 1)
+               ifid2, err := devInfoList.GetInterfaceID(deviceData)
                if err != nil {
-                       // Something is wrong with this device. Skip it.
                        continue
                }
 
@@ -367,54 +428,6 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
        return false, nil
 }
 
-// getInterfaceID returns network interface ID.
-//
-// After the device is created, it might take some time before the registry
-// key is populated. numAttempts parameter specifies the number of attempts
-// to read NetCfgInstanceId value from registry. A 1sec sleep is inserted
-// between retry attempts.
-//
-// Function returns the network interface ID.
-//
-func getInterfaceID(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, numAttempts int) (*windows.GUID, error) {
-       if numAttempts < 1 {
-               return nil, fmt.Errorf("Invalid numAttempts (expected: >=1, provided: %v)", numAttempts)
-       }
-
-       // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
-       key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ)
-       if err != nil {
-               return nil, errors.New("Device-specific registry key open failed: " + err.Error())
-       }
-       defer key.Close()
-
-       for {
-               // Read the NetCfgInstanceId value.
-               value, err := getRegStringValue(key, "NetCfgInstanceId")
-               if err != nil {
-                       if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
-                               numAttempts--
-                               if numAttempts > 0 {
-                                       // Wait and retry.
-                                       // TODO: Wait for a cancellable event instead.
-                                       time.Sleep(1000 * time.Millisecond)
-                                       continue
-                               }
-                       }
-
-                       return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
-               }
-
-               // Convert to windows.GUID.
-               ifid, err := guid.FromString(value)
-               if err != nil {
-                       return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
-               }
-
-               return ifid, err
-       }
-}
-
 //
 // GetInterfaceName returns network interface name.
 //