]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: IpConfig is a MULTI_SZ, and fix errors
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 10 May 2019 16:01:47 +0000 (18:01 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 10 May 2019 16:06:49 +0000 (18:06 +0200)
tun/wintun/registry/registry_windows.go
tun/wintun/wintun_windows.go

index 65da6bfc1375f3b241d50311ec403c278c350196..8c63c9b4e05789ded88c25932aa21610d48c87ac 100644 (file)
@@ -111,7 +111,7 @@ func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
 //
 // Key must be opened with at least QUERY_VALUE|KEY_NOTIFY access.
 //
-func getStringValueRetry(key registry.Key, name string, timeout time.Duration) (string, uint32, error) {
+func getStringValueRetry(key registry.Key, name string, timeout time.Duration, useFirstFromMulti bool) (string, uint32, error) {
        runtime.LockOSThread()
        defer runtime.UnlockOSThread()
 
@@ -128,8 +128,15 @@ func getStringValueRetry(key registry.Key, name string, timeout time.Duration) (
                        return "", 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
                }
 
-               value, valueType, err := key.GetStringValue(name)
-               if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
+               var value string
+               var values []string
+               var valueType uint32
+               if !useFirstFromMulti {
+                       value, valueType, err = key.GetStringValue(name)
+               } else {
+                       values, valueType, err = key.GetStringsValue(name)
+               }
+               if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND || (useFirstFromMulti && len(values) == 0) {
                        timeout := time.Until(deadline) / time.Millisecond
                        if timeout < 0 {
                                timeout = 0
@@ -144,7 +151,11 @@ func getStringValueRetry(key registry.Key, name string, timeout time.Duration) (
                } else if err != nil {
                        return "", 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
                } else {
-                       return value, valueType, nil
+                       if !useFirstFromMulti {
+                               return value, valueType, nil
+                       } else {
+                               return values[0], registry.SZ, nil
+                       }
                }
        }
 }
@@ -179,7 +190,14 @@ func expandString(value string, valueType uint32, err error) (string, error) {
 // Should expanding fail, original string value and nil error are returned.
 //
 func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
-       return expandString(getStringValueRetry(key, name, timeout))
+       return expandString(getStringValueRetry(key, name, timeout, false))
+}
+
+//
+// Same as GetStringValueWait, but returns the first from a MULTI_SZ.
+//
+func GetFirstStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
+       return expandString(getStringValueRetry(key, name, timeout, true))
 }
 
 //
index a73f5f2756c68eb647568293b7f46774ee658996..8163f94c671dac81a8930d02a7d8ac570d719f78 100644 (file)
@@ -44,14 +44,14 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
        // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
        key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
        if err != nil {
-               return nil, errors.New("Device-specific registry key open failed: " + err.Error())
+               return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
        }
        defer key.Close()
 
        // Read the NetCfgInstanceId value.
        valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId")
        if err != nil {
-               return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
+               return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err)
        }
 
        // Convert to windows.GUID.
@@ -63,13 +63,13 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
        // Read the NetLuidIndex value.
        luidIdx, _, err := key.GetIntegerValue("NetLuidIndex")
        if err != nil {
-               return nil, errors.New("RegQueryValue(\"NetLuidIndex\") failed: " + err.Error())
+               return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err)
        }
 
        // Read the NetLuidIndex value.
        ifType, _, err := key.GetIntegerValue("*IfType")
        if err != nil {
-               return nil, errors.New("RegQueryValue(\"*IfType\") failed: " + err.Error())
+               return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
        }
 
        return &Wintun{
@@ -95,7 +95,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
        // Create a list of network devices.
        devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName)
        if err != nil {
-               return nil, errors.New(fmt.Sprintf("SetupDiGetClassDevsEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error())
+               return nil, fmt.Errorf("SetupDiGetClassDevsEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err)
        }
        defer devInfoList.Close()
 
@@ -134,7 +134,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
                        const driverType = setupapi.SPDIT_COMPATDRIVER
                        err = devInfoList.BuildDriverInfoList(deviceData, driverType)
                        if err != nil {
-                               return nil, errors.New("SetupDiBuildDriverInfoList failed: " + err.Error())
+                               return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
                        }
                        defer devInfoList.DestroyDriverInfoList(deviceData, driverType)
 
@@ -188,44 +188,44 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
        // Create an empty device info set for network adapter device class.
        devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, hwndParent, machineName)
        if err != nil {
-               return nil, false, errors.New(fmt.Sprintf("SetupDiCreateDeviceInfoListEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error())
+               return nil, false, fmt.Errorf("SetupDiCreateDeviceInfoListEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err)
        }
 
        // Get the device class name from GUID.
        className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, machineName)
        if err != nil {
-               return nil, false, errors.New(fmt.Sprintf("SetupDiClassNameFromGuidEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error())
+               return nil, false, fmt.Errorf("SetupDiClassNameFromGuidEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err)
        }
 
        // Create a new device info element and add it to the device info set.
        deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, hwndParent, setupapi.DICD_GENERATE_ID)
        if err != nil {
-               return nil, false, errors.New("SetupDiCreateDeviceInfo failed: " + err.Error())
+               return nil, false, fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
        }
 
        // Set a device information element as the selected member of a device information set.
        err = devInfoList.SetSelectedDevice(deviceData)
        if err != nil {
-               return nil, false, errors.New("SetupDiSetSelectedDevice failed: " + err.Error())
+               return nil, false, fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
        }
 
        // Set Plug&Play device hardware ID property.
        err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID)
        if err != nil {
-               return nil, false, errors.New("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: " + err.Error())
+               return nil, false, fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
        }
 
        // Search for the driver.
        const driverType = setupapi.SPDIT_CLASSDRIVER
-       err = devInfoList.BuildDriverInfoList(deviceData, driverType)
+       err = devInfoList.BuildDriverInfoList(deviceData, driverType) //TODO: This takes ~510ms
        if err != nil {
-               return nil, false, errors.New("SetupDiBuildDriverInfoList failed: " + err.Error())
+               return nil, false, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
        }
        defer devInfoList.DestroyDriverInfoList(deviceData, driverType)
 
        driverDate := windows.Filetime{}
        driverVersion := uint64(0)
-       for index := 0; ; index++ {
+       for index := 0; ; index++ { //TODO: This loop takes ~600ms
                // Get a driver from the list.
                driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index)
                if err != nil {
@@ -266,7 +266,7 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
        // Call appropriate class installer.
        err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
        if err != nil {
-               return nil, false, errors.New("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: " + err.Error())
+               return nil, false, fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
        }
 
        // Register device co-installers if any. (Ignore errors)
@@ -275,16 +275,16 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
        // Install interfaces if any. (Ignore errors)
        devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
 
-       var wintun *Wintun
-       var rebootRequired bool
-       var key registry.Key
-
        // Install the device.
        err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
        if err != nil {
-               err = errors.New("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: " + err.Error())
+               err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
        }
 
+       var wintun *Wintun
+       var rebootRequired bool
+       var key registry.Key
+
        if err == nil {
                // Check if a system reboot is required. (Ignore errors)
                if ret, _ := checkReboot(devInfoList, deviceData); ret {
@@ -341,19 +341,23 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
                        wintun.GetTcpipAdapterRegKeyName(), registry.QUERY_VALUE|registryEx.KEY_NOTIFY,
                        waitForRegistryTimeout)
                if err == nil {
-                       _, err = registryEx.GetStringValueWait(key, "IpConfig", waitForRegistryTimeout)
+                       _, err = registryEx.GetFirstStringValueWait(key, "IpConfig", waitForRegistryTimeout)
                        key.Close()
                }
        }
 
+       var tcpipInterfaceRegKeyName string
        if err == nil {
-               // Wait for TCP/IP interface registry key to emerge.
-               key, err = registryEx.OpenKeyWait(
-                       registry.LOCAL_MACHINE,
-                       wintun.GetTcpipInterfaceRegKeyName(), registry.QUERY_VALUE,
-                       waitForRegistryTimeout)
+               tcpipInterfaceRegKeyName, err = wintun.GetTcpipInterfaceRegKeyName()
                if err == nil {
-                       key.Close()
+                       // Wait for TCP/IP interface registry key to emerge.
+                       key, err = registryEx.OpenKeyWait(
+                               registry.LOCAL_MACHINE,
+                               tcpipInterfaceRegKeyName, registry.QUERY_VALUE,
+                               waitForRegistryTimeout)
+                       if err == nil {
+                               key.Close()
+                       }
                }
        }
 
@@ -363,9 +367,9 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
 
        if err == nil {
                // Disable dead gateway detection on our interface.
-               key, err = registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetTcpipInterfaceRegKeyName(), registry.SET_VALUE)
+               key, err = registry.OpenKey(registry.LOCAL_MACHINE, tcpipInterfaceRegKeyName, registry.SET_VALUE)
                if err != nil {
-                       err = errors.New("Error opening interface-specific TCP/IP network registry key: " + err.Error())
+                       err = fmt.Errorf("Error opening interface-specific TCP/IP network registry key: %v", err)
                }
                key.SetDWordValue("EnableDeadGWDetect", 0)
                key.Close()
@@ -411,7 +415,7 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
        // Create a list of network devices.
        devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName)
        if err != nil {
-               return false, false, errors.New(fmt.Sprintf("SetupDiGetClassDevsEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error())
+               return false, false, fmt.Errorf("SetupDiGetClassDevsEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err.Error())
        }
        defer devInfoList.Close()
 
@@ -443,13 +447,13 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
                        // Set class installer parameters for DIF_REMOVE.
                        err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
                        if err != nil {
-                               return false, false, errors.New("SetupDiSetClassInstallParams failed: " + err.Error())
+                               return false, false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
                        }
 
                        // Call appropriate class installer.
                        err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
                        if err != nil {
-                               return false, false, errors.New("SetupDiCallClassInstaller failed: " + err.Error())
+                               return false, false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
                        }
 
                        // Check if a system reboot is required. (Ignore errors)
@@ -495,7 +499,7 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
 func (wintun *Wintun) GetInterfaceName() (string, error) {
        key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
        if err != nil {
-               return "", errors.New("Network-specific registry key open failed: " + err.Error())
+               return "", fmt.Errorf("Network-specific registry key open failed: %v", err)
        }
        defer key.Close()
 
@@ -516,7 +520,7 @@ func (wintun *Wintun) SetInterfaceName(ifname string) error {
        // Set the interface name. The above line should have done this too, but in case it failed, we force it.
        key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
        if err != nil {
-               return errors.New("Network-specific registry key open failed: " + err.Error())
+               return fmt.Errorf("Network-specific registry key open failed: %v", err)
        }
        defer key.Close()
        return key.SetStringValue("Name", ifname)
@@ -526,32 +530,33 @@ func (wintun *Wintun) SetInterfaceName(ifname string) error {
 // GetNetRegKeyName returns interface-specific network registry key name.
 //
 func (wintun *Wintun) GetNetRegKeyName() string {
-       return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(&wintun.CfgInstanceID))
+       return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%s\\%s\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(&wintun.CfgInstanceID))
 }
 
 //
 // GetTcpipAdapterRegKeyName returns adapter-specific TCP/IP network registry key name.
 //
 func (wintun *Wintun) GetTcpipAdapterRegKeyName() string {
-       return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid.ToString(&wintun.CfgInstanceID))
+       return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%s", guid.ToString(&wintun.CfgInstanceID))
 }
 
 //
 // GetTcpipInterfaceRegKeyName returns interface-specific TCP/IP network registry key name.
 //
-func (wintun *Wintun) GetTcpipInterfaceRegKeyName() string {
+func (wintun *Wintun) GetTcpipInterfaceRegKeyName() (path string, err error) {
        key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetTcpipAdapterRegKeyName(), registry.QUERY_VALUE)
        if err != nil {
-               err = errors.New("Error opening adapter-specific TCP/IP network registry key: " + err.Error())
+               return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
        }
-       defer key.Close()
-
-       path, err := registryEx.GetStringValue(key, "IpConfig")
+       paths, _, err := key.GetStringsValue("IpConfig")
+       key.Close()
        if err != nil {
-               err = errors.New("Error reading IpConfig: " + err.Error())
+               return "", fmt.Errorf("Error reading IpConfig registry key: %v", err)
        }
-
-       return "SYSTEM\\CurrentControlSet\\Services\\" + path
+       if len(paths) == 0 {
+               return "", errors.New("No TCP/IP interfaces found on adapter")
+       }
+       return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
 }
 
 //