]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: Refactor network registry key name generation
authorSimon Rozman <simon@rozman.si>
Thu, 7 Mar 2019 14:34:34 +0000 (15:34 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 7 Mar 2019 20:12:20 +0000 (21:12 +0100)
Signed-off-by: Simon Rozman <simon@rozman.si>
tun/wintun/wintun_windows.go

index 69fd30cd912f0c9c73195e23a054d736c86c4492..ab865a916b370b7487ab403180943cb962a97e52 100644 (file)
@@ -229,7 +229,7 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
        // Install interfaces if any.
        devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
 
-       var ifid *windows.GUID
+       var wintun *Wintun
        var rebootRequired bool
 
        // Install the device.
@@ -240,10 +240,10 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
                        rebootRequired = true
                }
 
-               // Get network interface ID from registry. DIF_INSTALLDEVICE returns almost immediately,
-               // while the device installation continues in the background. It might take a while, before
-               // all registry keys and values are populated.
-               getInterfaceID := func() (*windows.GUID, error) {
+               // Get network interface. DIF_INSTALLDEVICE returns almost immediately, while the device
+               // installation continues in the background. It might take a while, before all registry
+               // keys and values are populated.
+               getInterface := func() (*Wintun, error) {
                        // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
                        keyDev, err := devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ)
                        if err != nil {
@@ -267,7 +267,8 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
                                return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
                        }
 
-                       keyNetName := fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), value)
+                       wintun := (*Wintun)(ifid)
+                       keyNetName := wintun.GetNetRegKeyName()
                        keyNet, err := registry.OpenKey(registry.LOCAL_MACHINE, keyNetName, registry.QUERY_VALUE)
                        if err != nil {
                                if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
@@ -293,11 +294,11 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
                                return nil, fmt.Errorf("Interface name registry value is not REG_SZ or REG_EXPAND_SZ (expected: %v or %v, provided: %v)", registry.SZ, registry.EXPAND_SZ, valueType)
                        }
 
-                       // TUN interface is ready. (As far as we need it.)
-                       return ifid, nil
+                       // TUN interface is ready. (As much as we need it.)
+                       return wintun, nil
                }
                for numAttempts := 0; numAttempts < 30; numAttempts++ {
-                       ifid, err = getInterfaceID()
+                       wintun, err = getInterface()
                        if err != nil {
                                if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
                                        // Wait and retry. TODO: Wait for a cancellable event instead.
@@ -311,7 +312,7 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
        }
 
        if err == nil {
-               return (*Wintun)(ifid), rebootRequired, nil
+               return wintun, rebootRequired, nil
        }
 
        // The interface failed to install, or the interface ID was unobtainable. Clean-up.
@@ -432,9 +433,9 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
 // GetInterfaceName returns network interface name.
 //
 func (wintun *Wintun) GetInterfaceName() (string, error) {
-       key, err := wintun.openNetRegKey(registry.QUERY_VALUE)
+       key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
        if err != nil {
-               return "", err
+               return "", errors.New("Network-specific registry key open failed: " + err.Error())
        }
        defer key.Close()
 
@@ -446,9 +447,9 @@ func (wintun *Wintun) GetInterfaceName() (string, error) {
 // SetInterfaceName sets network interface name.
 //
 func (wintun *Wintun) SetInterfaceName(ifname string) error {
-       key, err := wintun.openNetRegKey(registry.SET_VALUE)
+       key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
        if err != nil {
-               return err
+               return errors.New("Network-specific registry key open failed: " + err.Error())
        }
        defer key.Close()
 
@@ -457,16 +458,11 @@ func (wintun *Wintun) SetInterfaceName(ifname string) error {
 }
 
 //
-// openNetRegKey opens interface-specific network registry key.
+// GetNetRegKeyName returns interface-specific network registry key name.
 //
-func (wintun *Wintun) openNetRegKey(access uint32) (registry.Key, error) {
+func (wintun *Wintun) GetNetRegKeyName() string {
        ifid := (*windows.GUID)(wintun)
-       key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(ifid)), access)
-       if err != nil {
-               return 0, errors.New("Network-specific registry key open failed: " + err.Error())
-       }
-
-       return key, nil
+       return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(ifid))
 }
 
 //