]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: add more retry loops
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 31 Mar 2019 08:17:11 +0000 (10:17 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 1 Apr 2019 07:07:43 +0000 (09:07 +0200)
tun/tun_windows.go
tun/wintun/registryhacks_windows.go [new file with mode: 0644]
tun/wintun/wintun_windows.go

index 9428373a16381f2a35c1d0e090a754ad767089e1..948f08dfac54a63e1b7a19068b9ac2b2483d1ee4 100644 (file)
@@ -75,18 +75,11 @@ func CreateTUN(ifname string) (TUNDevice, error) {
                return nil, err
        }
 
-       go func() {
-               retries := retryTimeout * retryRate
-               for {
-                       err := wt.SetInterfaceName(ifname)
-                       if err != nil && retries > 0 {
-                               time.Sleep(time.Second / retryRate)
-                               retries--
-                               continue
-                       }
-                       return
-               }
-       }()
+       err = wt.SetInterfaceName(ifname)
+       if err != nil {
+               wt.DeleteInterface(0)
+               return nil, err
+       }
 
        err = wt.FlushInterface()
        if err != nil {
diff --git a/tun/wintun/registryhacks_windows.go b/tun/wintun/registryhacks_windows.go
new file mode 100644 (file)
index 0000000..62a629a
--- /dev/null
@@ -0,0 +1,42 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wintun
+
+import (
+       "golang.org/x/sys/windows/registry"
+       "time"
+)
+
+const (
+       numRetries = 25
+       retryTimeout = 100 * time.Millisecond
+)
+
+func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) {
+       for i := 0; i < numRetries; i++ {
+               key, err = registry.OpenKey(k, path, access)
+               if err == nil {
+                       break
+               }
+               if i != numRetries - 1 {
+                       time.Sleep(retryTimeout)
+               }
+       }
+       return
+}
+
+func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) {
+       for i := 0; i < numRetries; i++ {
+               val, valtype, err = k.GetStringValue(name)
+               if err == nil {
+                       break
+               }
+               if i != numRetries - 1 {
+                       time.Sleep(retryTimeout)
+               }
+       }
+       return
+}
index ba94b11da1ee7ec170b67729660c96f3ed11ef36..77e83a0b3356017ff9e9cfe7f77b20d276793a6b 100644 (file)
@@ -48,22 +48,14 @@ func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
        var valueStr string
        var valueType uint32
 
-       //TODO: Figure out a way to not need to loop like this.
-       for i := 0; i < 30; i++ {
-               // Read the NetCfgInstanceId value.
-               valueStr, valueType, err = key.GetStringValue("NetCfgInstanceId")
-               if err != nil {
-                       time.Sleep(time.Millisecond * 100)
-                       continue
-               }
-               if valueType != registry.SZ {
-                       return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
-               }
-               break
-       }
+       // Read the NetCfgInstanceId value.
+       valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId")
        if err != nil {
                return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
        }
+       if valueType != registry.SZ {
+               return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
+       }
 
        // Convert to windows.GUID.
        ifid, err := guid.FromString(valueStr)
@@ -117,7 +109,6 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
        // "foobar" would cause conflict with "FooBar".
        ifname = strings.ToLower(ifname)
 
-       // Iterate.
        for index := 0; ; index++ {
                // Get the device from the list. Should anything be wrong with this device, continue with next.
                deviceData, err := devInfoList.EnumDeviceInfo(index)
@@ -174,7 +165,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
                        }
 
                        // This interface is not using Wintun driver.
-                       return wintun, errors.New("Foreign network interface with the same name exists")
+                       return nil, errors.New("Foreign network interface with the same name exists")
                }
        }
 
@@ -444,7 +435,7 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
 // GetInterfaceName returns network interface name.
 //
 func (wintun *Wintun) GetInterfaceName() (string, error) {
-       key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
+       key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
        if err != nil {
                return "", errors.New("Network-specific registry key open failed: " + err.Error())
        }
@@ -458,7 +449,7 @@ func (wintun *Wintun) GetInterfaceName() (string, error) {
 // SetInterfaceName sets network interface name.
 //
 func (wintun *Wintun) SetInterfaceName(ifname string) error {
-       key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
+       key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
        if err != nil {
                return errors.New("Network-specific registry key open failed: " + err.Error())
        }
@@ -483,7 +474,7 @@ func (wintun *Wintun) GetNetRegKeyName() string {
 //
 func getRegStringValue(key registry.Key, name string) (string, error) {
        // Read string value.
-       value, valueType, err := key.GetStringValue(name)
+       value, valueType, err := keyGetStringValueRetry(key, name)
        if err != nil {
                return "", err
        }