]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: registry: revise value reading
authorSimon Rozman <simon@rozman.si>
Sat, 11 May 2019 04:21:02 +0000 (06:21 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sat, 11 May 2019 15:14:37 +0000 (17:14 +0200)
- Make getStringValueRetry() reusable for reading any value type. This
  merges code from GetIntegerValueWait().
- expandString() >> toString() and extend to support REG_MULTI_SZ
  (to return first value of REG_MULTI_SZ). Furthermore, doing our own
  UTF-16 to UTF-8 conversion works around a bug in windows/registry's
  GetStringValue() non-zero terminated string handling.
- Provide toInteger() analogous to toString()
- GetStringValueWait() tolerates and reads REG_MULTI_SZ too now. It
  returns REG_MULTI_SZ[0], making GetFirstStringValueWait() redundant.

Signed-off-by: Simon Rozman <simon@rozman.si>
tun/wintun/registry/registry_windows.go
tun/wintun/wintun_windows.go

index 415aa009b535d6e6a3b3f2c2e076dff52f8acad7..b996c2359c6305197b594a90c08db52f5832ddf7 100644 (file)
@@ -10,7 +10,9 @@ import (
        "fmt"
        "runtime"
        "strings"
+       "syscall"
        "time"
+       "unsafe"
 
        "golang.org/x/sys/windows"
        "golang.org/x/sys/windows/registry"
@@ -102,18 +104,44 @@ func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
 }
 
 //
-// getStringValueRetry function reads a string value from registry. It waits for
+// getValue is the same as windows/registry's getValue, which is unfortunately
+// private.
+//
+func getValue(k registry.Key, name string, buf []byte) ([]byte, uint32, error) {
+       p, err := syscall.UTF16PtrFromString(name)
+       if err != nil {
+               return nil, 0, err
+       }
+       var t uint32
+       n := uint32(len(buf))
+       for {
+               err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n)
+               if err == nil {
+                       return buf[:n], t, nil
+               }
+               if err != syscall.ERROR_MORE_DATA {
+                       return nil, 0, err
+               }
+               if n <= uint32(len(buf)) {
+                       return nil, 0, err
+               }
+               buf = make([]byte, n)
+       }
+}
+
+//
+// getValueRetry function reads any value from registry. It waits for
 // the registry value to become available or returns error on timeout.
 //
 // Key must be opened with at least QUERY_VALUE|NOTIFY access.
 //
-func getStringValueRetry(key registry.Key, name string, timeout time.Duration, useFirstFromMulti bool) (string, uint32, error) {
+func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) {
        runtime.LockOSThread()
        defer runtime.UnlockOSThread()
 
        event, err := windows.CreateEvent(nil, 0, 0, nil)
        if err != nil {
-               return "", 0, fmt.Errorf("Error creating event: %v", err)
+               return nil, 0, fmt.Errorf("Error creating event: %v", err)
        }
        defer windows.CloseHandle(event)
 
@@ -121,46 +149,47 @@ func getStringValueRetry(key registry.Key, name string, timeout time.Duration, u
        for {
                err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
                if err != nil {
-                       return "", 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
+                       return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
                }
 
-               var value string
-               var values []string
-               var valueType uint32
-               if !useFirstFromMulti {
-                       value, valueType, err = key.GetStringValue(name)
-               } else {
-                       values, valueType, err = key.GetStringsValue(name)
-               }
-               if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND || (useFirstFromMulti && len(values) == 0) {
+               buf, valueType, err := getValue(key, name, buf)
+               if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
                        timeout := time.Until(deadline) / time.Millisecond
                        if timeout < 0 {
                                timeout = 0
                        }
                        s, err := windows.WaitForSingleObject(event, uint32(timeout))
                        if err != nil {
-                               return "", 0, fmt.Errorf("Unable to wait on registry value: %v", err)
+                               return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err)
                        }
                        if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
-                               return "", 0, errors.New("Timeout waiting for registry value")
+                               return nil, 0, errors.New("Timeout waiting for registry value")
                        }
                } else if err != nil {
-                       return "", 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
+                       return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
                } else {
-                       if !useFirstFromMulti {
-                               return value, valueType, nil
-                       } else {
-                               return values[0], registry.SZ, nil
-                       }
+                       return buf, valueType, nil
                }
        }
 }
 
-func expandString(value string, valueType uint32, err error) (string, error) {
+func toString(buf []byte, valueType uint32, err error) (string, error) {
        if err != nil {
                return "", err
        }
 
+       var value string
+       switch valueType {
+       case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ:
+               if len(buf) == 0 {
+                       return "", nil
+               }
+               value = syscall.UTF16ToString((*[1 << 29]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
+
+       default:
+               return "", registry.ErrUnexpectedType
+       }
+
        if valueType != registry.EXPAND_SZ {
                // Value does not require expansion.
                return value, nil
@@ -176,6 +205,29 @@ func expandString(value string, valueType uint32, err error) (string, error) {
        return valueExp, nil
 }
 
+func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
+       if err != nil {
+               return 0, err
+       }
+
+       switch valueType {
+       case registry.DWORD:
+               if len(buf) != 4 {
+                       return 0, errors.New("DWORD value is not 4 bytes long")
+               }
+               return uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), nil
+
+       case registry.QWORD:
+               if len(buf) != 8 {
+                       return 0, errors.New("QWORD value is not 8 bytes long")
+               }
+               return uint64(*(*uint64)(unsafe.Pointer(&buf[0]))), nil
+
+       default:
+               return 0, registry.ErrUnexpectedType
+       }
+}
+
 //
 // GetStringValueWait function reads a string value from registry. It waits
 // for the registry value to become available or returns error on timeout.
@@ -185,15 +237,10 @@ func expandString(value string, valueType uint32, err error) (string, error) {
 // If the value type is REG_EXPAND_SZ the environment variables are expanded.
 // Should expanding fail, original string value and nil error are returned.
 //
-func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
-       return expandString(getStringValueRetry(key, name, timeout, false))
-}
-
-//
-// Same as GetStringValueWait, but returns the first from a MULTI_SZ.
+// If the value type is REG_MULTI_SZ only the first string is returned.
 //
-func GetFirstStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
-       return expandString(getStringValueRetry(key, name, timeout, true))
+func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
+       return toString(getValueRetry(key, name, make([]byte, 64), timeout))
 }
 
 //
@@ -204,8 +251,10 @@ func GetFirstStringValueWait(key registry.Key, name string, timeout time.Duratio
 // If the value type is REG_EXPAND_SZ the environment variables are expanded.
 // Should expanding fail, original string value and nil error are returned.
 //
+// If the value type is REG_MULTI_SZ only the first string is returned.
+//
 func GetStringValue(key registry.Key, name string) (string, error) {
-       return expandString(key.GetStringValue(name))
+       return toString(getValue(key, name, make([]byte, 64)))
 }
 
 //
@@ -216,39 +265,5 @@ func GetStringValue(key registry.Key, name string) (string, error) {
 // Key must be opened with at least QUERY_VALUE|NOTIFY access.
 //
 func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) {
-       runtime.LockOSThread()
-       defer runtime.UnlockOSThread()
-
-       event, err := windows.CreateEvent(nil, 0, 0, nil)
-       if err != nil {
-               return 0, fmt.Errorf("Error creating event: %v", err)
-       }
-       defer windows.CloseHandle(event)
-
-       deadline := time.Now().Add(timeout)
-       for {
-               err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
-               if err != nil {
-                       return 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
-               }
-
-               value, _, err := key.GetIntegerValue(name)
-               if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
-                       timeout := time.Until(deadline) / time.Millisecond
-                       if timeout < 0 {
-                               timeout = 0
-                       }
-                       s, err := windows.WaitForSingleObject(event, uint32(timeout))
-                       if err != nil {
-                               return 0, fmt.Errorf("Unable to wait on registry value: %v", err)
-                       }
-                       if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
-                               return 0, errors.New("Timeout waiting for registry value")
-                       }
-               } else if err != nil {
-                       return 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
-               } else {
-                       return value, nil
-               }
-       }
+       return toInteger(getValueRetry(key, name, make([]byte, 8), timeout))
 }
index 4dfc0bc1f2640620f0c97187ff4ed9b9f5beb486..2e32f64b701a481d33aeddf2360b975aa9db1f6c 100644 (file)
@@ -342,7 +342,7 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
                        wintun.GetTcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
                        waitForRegistryTimeout)
                if err == nil {
-                       _, err = registryEx.GetFirstStringValueWait(key, "IpConfig", waitForRegistryTimeout)
+                       _, err = registryEx.GetStringValueWait(key, "IpConfig", waitForRegistryTimeout)
                        key.Close()
                }
        }