]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: use new swdevice-based API for upcoming Wintun 0.14
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 12 Oct 2021 06:26:46 +0000 (00:26 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 12 Oct 2021 06:26:46 +0000 (00:26 -0600)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/tun_windows.go
tun/wintun/wintun_windows.go

index ff16e2f7ca248bd31d4ee528384cb026f3a7ccc5..381a842667cce039b4c9f63f02645f38eb0434c0 100644 (file)
@@ -8,7 +8,6 @@ package tun
 import (
        "errors"
        "fmt"
-       "log"
        "os"
        "sync"
        "sync/atomic"
@@ -35,6 +34,7 @@ type rateJuggler struct {
 
 type NativeTun struct {
        wt        *wintun.Adapter
+       name      string
        handle    windows.Handle
        rate      rateJuggler
        session   wintun.Session
@@ -46,7 +46,7 @@ type NativeTun struct {
        forcedMTU int
 }
 
-var WintunPool, _ = wintun.MakePool("WireGuard")
+var WintunTunnelType = "WireGuard"
 var WintunStaticRequestedGUID *windows.GUID
 
 //go:linkname procyield runtime.procyield
@@ -68,25 +68,10 @@ func CreateTUN(ifname string, mtu int) (Device, error) {
 // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
 //
 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
-       var err error
-       var wt *wintun.Adapter
-
-       // Does an interface with this name already exist?
-       wt, err = WintunPool.OpenAdapter(ifname)
-       if err == nil {
-               // If so, we delete it, in case it has weird residual configuration.
-               _, err = wt.Delete(true)
-               if err != nil {
-                       return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
-               }
-       }
-       wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
+       wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
        if err != nil {
                return nil, fmt.Errorf("Error creating interface: %w", err)
        }
-       if rebootRequired {
-               log.Println("Windows indicated a reboot is required.")
-       }
 
        forcedMTU := 1420
        if mtu > 0 {
@@ -95,6 +80,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
 
        tun := &NativeTun{
                wt:        wt,
+               name:      ifname,
                handle:    windows.InvalidHandle,
                events:    make(chan Event, 10),
                forcedMTU: forcedMTU,
@@ -102,7 +88,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
 
        tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
        if err != nil {
-               tun.wt.Delete(false)
+               tun.wt.Close()
                close(tun.events)
                return nil, fmt.Errorf("Error starting session: %w", err)
        }
@@ -111,12 +97,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
 }
 
 func (tun *NativeTun) Name() (string, error) {
-       tun.running.Add(1)
-       defer tun.running.Done()
-       if atomic.LoadInt32(&tun.close) == 1 {
-               return "", os.ErrClosed
-       }
-       return tun.wt.Name()
+       return tun.name, nil
 }
 
 func (tun *NativeTun) File() *os.File {
@@ -135,7 +116,7 @@ func (tun *NativeTun) Close() error {
                tun.running.Wait()
                tun.session.End()
                if tun.wt != nil {
-                       _, err = tun.wt.Delete(false)
+                       tun.wt.Close()
                }
                close(tun.events)
        })
index 6c5a00d22ae2d5990faac848fb059609b36729d2..4edad917262bf39fdeb19ffa5be66bfef8f9f248 100644 (file)
@@ -6,7 +6,6 @@
 package wintun
 
 import (
-       "errors"
        "log"
        "runtime"
        "syscall"
@@ -23,175 +22,107 @@ const (
        logErr
 )
 
-const (
-       PoolNameMax    = 256
-       AdapterNameMax = 128
-)
+const AdapterNameMax = 128
 
-type Pool [PoolNameMax]uint16
 type Adapter struct {
        handle uintptr
 }
 
 var (
-       modwintun = newLazyDLL("wintun.dll", setupLogger)
-
+       modwintun                         = newLazyDLL("wintun.dll", setupLogger)
        procWintunCreateAdapter           = modwintun.NewProc("WintunCreateAdapter")
-       procWintunDeleteAdapter           = modwintun.NewProc("WintunDeleteAdapter")
-       procWintunDeletePoolDriver        = modwintun.NewProc("WintunDeletePoolDriver")
-       procWintunEnumAdapters            = modwintun.NewProc("WintunEnumAdapters")
-       procWintunFreeAdapter             = modwintun.NewProc("WintunFreeAdapter")
        procWintunOpenAdapter             = modwintun.NewProc("WintunOpenAdapter")
+       procWintunCloseAdapter            = modwintun.NewProc("WintunCloseAdapter")
+       procWintunDeleteDriver            = modwintun.NewProc("WintunDeleteDriver")
        procWintunGetAdapterLUID          = modwintun.NewProc("WintunGetAdapterLUID")
-       procWintunGetAdapterName          = modwintun.NewProc("WintunGetAdapterName")
        procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
-       procWintunSetAdapterName          = modwintun.NewProc("WintunSetAdapterName")
 )
 
-func setupLogger(dll *lazyDLL) {
-       syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
-               log.Println("[Wintun]", windows.UTF16PtrToString(msg))
-               return 0
-       }), 0, 0)
+type TimestampedWriter interface {
+       WriteWithTimestamp(p []byte, ts int64) (n int, err error)
 }
 
-func MakePool(poolName string) (pool *Pool, err error) {
-       poolName16, err := windows.UTF16FromString(poolName)
-       if err != nil {
-               return
-       }
-       if len(poolName16) > PoolNameMax {
-               err = errors.New("Pool name too long")
-               return
+func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
+       if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
+               tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
+       } else {
+               log.Println(windows.UTF16PtrToString(msg))
        }
-       pool = &Pool{}
-       copy(pool[:], poolName16)
-       return
+       return 0
 }
 
-func (pool *Pool) String() string {
-       return windows.UTF16ToString(pool[:])
+func setupLogger(dll *lazyDLL) {
+       var callback uintptr
+       if runtime.GOARCH == "386" || runtime.GOARCH == "arm" {
+               callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
+                       return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
+               })
+       } else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
+               callback = windows.NewCallback(logMessage)
+       }
+       syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
 }
 
-func freeAdapter(wintun *Adapter) {
-       syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0)
+func closeAdapter(wintun *Adapter) {
+       syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
 }
 
-// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or
-// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a
-// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
-// released after use.
-func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) {
-       ifname16, err := windows.UTF16PtrFromString(ifname)
+// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
+// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
+// the GUID of the created network adapter, which then influences NLA generation
+// deterministically. If it is set to nil, the GUID is chosen by the system at random,
+// and hence a new NLA entry is created for each new adapter.
+func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
+       var name16 *uint16
+       name16, err = windows.UTF16PtrFromString(name)
        if err != nil {
-               return nil, err
-       }
-       r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
-       if r0 == 0 {
-               err = e1
                return
        }
-       wintun = &Adapter{r0}
-       runtime.SetFinalizer(wintun, freeAdapter)
-       return
-}
-
-// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while
-// requestedGUID is the GUID of the created network adapter, which then influences NLA generation
-// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a
-// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it
-// uses is completely undocumented, and so there could be minor interesting complications with its
-// usage. This function returns the network adapter ID and a flag if reboot is required.
-func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) {
-       var ifname16 *uint16
-       ifname16, err = windows.UTF16PtrFromString(ifname)
+       var tunnelType16 *uint16
+       tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
        if err != nil {
                return
        }
-       var _p0 uint32
-       r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
-       rebootRequired = _p0 != 0
+       r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
        if r0 == 0 {
                err = e1
                return
        }
-       wintun = &Adapter{r0}
-       runtime.SetFinalizer(wintun, freeAdapter)
-       return
-}
-
-// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns
-// a bool indicating whether a reboot is required.
-func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) {
-       var _p0 uint32
-       if forceCloseSessions {
-               _p0 = 1
-       }
-       var _p1 uint32
-       r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1)))
-       rebootRequired = _p1 != 0
-       if r1 == 0 {
-               err = e1
-       }
+       wintun = &Adapter{handle: r0}
+       runtime.SetFinalizer(wintun, closeAdapter)
        return
 }
 
-// DeleteMatchingAdapters deletes all Wintun adapters, which match
-// given criteria, and returns which ones it deleted, whether a reboot
-// is required after, and which errors occurred during the process.
-func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) {
-       cb := func(handle uintptr, _ uintptr) int {
-               adapter := &Adapter{handle}
-               if !matches(adapter) {
-                       return 1
-               }
-               rebootRequired2, err := adapter.Delete(forceCloseSessions)
-               if err != nil {
-                       errors = append(errors, err)
-                       return 1
-               }
-               rebootRequired = rebootRequired || rebootRequired2
-               return 1
-       }
-       r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
-       if r1 == 0 {
-               errors = append(errors, e1)
+// OpenAdapter opens an existing Wintun adapter by name.
+func OpenAdapter(name string) (wintun *Adapter, err error) {
+       var name16 *uint16
+       name16, err = windows.UTF16PtrFromString(name)
+       if err != nil {
+               return
        }
-       return
-}
-
-// Name returns the name of the Wintun adapter.
-func (wintun *Adapter) Name() (ifname string, err error) {
-       var ifname16 [AdapterNameMax]uint16
-       r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
-       if r1 == 0 {
+       r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
+       if r0 == 0 {
                err = e1
                return
        }
-       ifname = windows.UTF16ToString(ifname16[:])
+       wintun = &Adapter{handle: r0}
+       runtime.SetFinalizer(wintun, closeAdapter)
        return
 }
 
-// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other
-// pools, also removes Wintun from the driver store, usually called by uninstallers.
-func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
-       var _p0 uint32
-       r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
-       rebootRequired = _p0 != 0
+// Close closes a Wintun adapter.
+func (wintun *Adapter) Close() (err error) {
+       runtime.SetFinalizer(wintun, nil)
+       r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
        if r1 == 0 {
                err = e1
        }
        return
-
 }
 
-// SetName sets name of the Wintun adapter.
-func (wintun *Adapter) SetName(ifname string) (err error) {
-       ifname16, err := windows.UTF16FromString(ifname)
-       if err != nil {
-               return err
-       }
-       r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
+// Uninstall removes the driver from the system if no drivers are currently in use.
+func Uninstall() (err error) {
+       r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
        if r1 == 0 {
                err = e1
        }