]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: Use native Win32 API for I/O
authorSimon Rozman <simon@rozman.si>
Wed, 20 Mar 2019 20:45:40 +0000 (21:45 +0100)
committerSimon Rozman <simon@rozman.si>
Wed, 20 Mar 2019 23:56:45 +0000 (00:56 +0100)
Signed-off-by: Simon Rozman <simon@rozman.si>
tun/mksyscall.go [new file with mode: 0644]
tun/tun.go
tun/tun_default.go [new file with mode: 0644]
tun/tun_windows.go
tun/ztun_windows.go [new file with mode: 0644]

diff --git a/tun/mksyscall.go b/tun/mksyscall.go
new file mode 100644 (file)
index 0000000..06bb41e
--- /dev/null
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output ztun_windows.go tun_windows.go
index f38ee3156f6b943af0cc32fe04a8bd25797c880d..c4b6cacec97e88cb19be46635ed4202ea12fcbd0 100644 (file)
@@ -6,7 +6,6 @@
 package tun
 
 import (
-       "fmt"
        "os"
 )
 
@@ -27,15 +26,3 @@ type TUNDevice interface {
        Events() chan TUNEvent          // returns a constant channel of events related to the device
        Close() error                   // stops the device and closes the event channel
 }
-
-func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) {
-       sysconn, err := tun.tunFile.SyscallConn()
-       if err != nil {
-               tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
-               return
-       }
-       err = sysconn.Control(fn)
-       if err != nil {
-               tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
-       }
-}
diff --git a/tun/tun_default.go b/tun/tun_default.go
new file mode 100644 (file)
index 0000000..31747a2
--- /dev/null
@@ -0,0 +1,24 @@
+// +build !windows
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+       "fmt"
+)
+
+func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) {
+       sysconn, err := tun.tunFile.SyscallConn()
+       if err != nil {
+               tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
+               return
+       }
+       err = sysconn.Control(fn)
+       if err != nil {
+               tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
+       }
+}
index 81ba18ec6c350cd43824bf19d92a742d5559a1b5..15d9ae28526b34dbb918f63158ca578a675327c8 100644 (file)
@@ -9,6 +9,8 @@ import (
        "errors"
        "os"
        "sync"
+       "syscall"
+       "time"
        "unsafe"
 
        "golang.org/x/sys/windows"
@@ -20,6 +22,8 @@ const (
        packetExchangeAlignment uint32 = 16                               // Number of bytes packets are aligned to in exchange buffers
        packetSizeMax           uint32 = 0xf000 - packetExchangeAlignment // Maximum packet size
        packetExchangeSize      uint32 = 0x100000                         // Exchange buffer size (defaults to 1MiB)
+       retryRate                      = 4                                // Number of retries per second to reopen device pipe
+       retryTimeout                   = 5                                // Number of seconds to tolerate adapter unavailable
 )
 
 type exchgBufRead struct {
@@ -36,9 +40,10 @@ type exchgBufWrite struct {
 
 type NativeTun struct {
        wt        *wintun.Wintun
-       tunName   string
-       tunFile   *os.File
+       tunName   *uint16
+       tunFile   windows.Handle
        tunLock   sync.Mutex
+       close     bool
        rdBuff    *exchgBufRead
        wrBuff    *exchgBufWrite
        events    chan TUNEvent
@@ -46,6 +51,8 @@ type NativeTun struct {
        forcedMtu int
 }
 
+//sys  getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) = kernel32.GetOverlappedResult
+
 func packetAlign(size uint32) uint32 {
        return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1)
 }
@@ -83,9 +90,16 @@ func CreateTUN(ifname string) (TUNDevice, error) {
                return nil, errors.New("Flushing interface failed: " + err.Error())
        }
 
+       tunNameUTF16, err := windows.UTF16PtrFromString(wt.DataFileName())
+       if err != nil {
+               wt.DeleteInterface(0)
+               return nil, err
+       }
+
        return &NativeTun{
                wt:        wt,
-               tunName:   wt.DataFileName(),
+               tunName:   tunNameUTF16,
+               tunFile:   windows.InvalidHandle,
                rdBuff:    &exchgBufRead{},
                wrBuff:    &exchgBufWrite{},
                events:    make(chan TUNEvent, 10),
@@ -94,42 +108,67 @@ func CreateTUN(ifname string) (TUNDevice, error) {
        }, nil
 }
 
-func (tun *NativeTun) openTUN() {
+func (tun *NativeTun) openTUN() error {
+       retries := retryTimeout * retryRate
        for {
-               file, err := os.OpenFile(tun.tunName, os.O_RDWR, 0)
+               if tun.close {
+                       return errors.New("Cancelled")
+               }
+
+               file, err := windows.CreateFile(tun.tunName, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED|0x20000000 /*windows.FILE_FLAG_NO_BUFFERING*/, 0)
                if err != nil {
-                       continue
+                       if retries > 0 {
+                               time.Sleep(time.Second / retryRate)
+                               retries--
+                               continue
+                       }
+                       return err
                }
+
                tun.tunFile = file
+               return nil
        }
 }
 
 func (tun *NativeTun) closeTUN() (err error) {
-       if tun.tunFile != nil {
+       if tun.tunFile != windows.InvalidHandle {
                tun.tunLock.Lock()
                defer tun.tunLock.Unlock()
-               if tun.tunFile == nil {
+               if tun.tunFile == windows.InvalidHandle {
                        return
                }
                t := tun.tunFile
-               tun.tunFile = nil
-               err = t.Close()
+               tun.tunFile = windows.InvalidHandle
+               err = windows.CloseHandle(t)
        }
        return
 }
 
-func (tun *NativeTun) getTUN() (*os.File, error) {
-       if tun.tunFile == nil {
+func (tun *NativeTun) getTUN() (windows.Handle, error) {
+       if tun.tunFile == windows.InvalidHandle {
                tun.tunLock.Lock()
                defer tun.tunLock.Unlock()
-               if tun.tunFile != nil {
+               if tun.tunFile != windows.InvalidHandle {
                        return tun.tunFile, nil
                }
-               tun.openTUN()
+               err := tun.openTUN()
+               if err != nil {
+                       return windows.InvalidHandle, err
+               }
        }
        return tun.tunFile, nil
 }
 
+func (tun *NativeTun) isIOCancelled(err error) bool {
+       // Read&WriteFile() return the same ERROR_OPERATION_ABORTED if we close the handle
+       // or the TUN device is put down. We need a "close" flag to distinguish.
+       en, ok := err.(syscall.Errno)
+       if tun.close && ok && en == windows.ERROR_OPERATION_ABORTED {
+               return true
+       }
+       return false
+}
+
 func (tun *NativeTun) Name() (string, error) {
        return tun.wt.GetInterfaceName()
 }
@@ -143,6 +182,7 @@ func (tun *NativeTun) Events() chan TUNEvent {
 }
 
 func (tun *NativeTun) Close() error {
+       tun.close = true
        err1 := tun.closeTUN()
 
        if tun.events != nil {
@@ -199,15 +239,21 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
                }
 
                // Fill queue.
-               n, err := file.Read(tun.rdBuff.data[:])
+               var n uint32
+               overlapped := &windows.Overlapped{}
+               err = windows.ReadFile(file, tun.rdBuff.data[:], &n, overlapped)
                if err != nil {
-                       if pe, ok := err.(*os.PathError); ok && pe.Err == os.ErrClosed {
-                               return 0, err
+                       if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING {
+                               err = getOverlappedResult(file, overlapped, &n, true)
+                       }
+                       if err != nil {
+                               tun.rdBuff.avail = 0
+                               if tun.isIOCancelled(err) {
+                                       return 0, err
+                               }
+                               tun.closeTUN()
+                               continue
                        }
-                       // TUN interface stopped, failed, etc. Retry.
-                       tun.rdBuff.avail = 0
-                       tun.closeTUN()
-                       continue
                }
                tun.rdBuff.offset = 0
                tun.rdBuff.avail = uint32(n)
@@ -224,13 +270,22 @@ func (tun *NativeTun) flush() error {
        }
 
        // Flush write buffer.
-       _, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset])
+       var n uint32
+       overlapped := &windows.Overlapped{}
+       err = windows.WriteFile(file, tun.wrBuff.data[:tun.wrBuff.offset], &n, overlapped)
        tun.wrBuff.packetNum = 0
        tun.wrBuff.offset = 0
        if err != nil {
-               // TUN interface stopped, failed, etc. Drop.
-               tun.closeTUN()
-               return err
+               if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING {
+                       err = getOverlappedResult(file, overlapped, &n, true)
+               }
+               if err != nil {
+                       if tun.isIOCancelled(err) {
+                               return err
+                       }
+                       tun.closeTUN()
+                       return nil
+               }
        }
 
        return nil
diff --git a/tun/ztun_windows.go b/tun/ztun_windows.go
new file mode 100644 (file)
index 0000000..ed779c1
--- /dev/null
@@ -0,0 +1,61 @@
+// Code generated by 'go generate'; DO NOT EDIT.
+
+package tun
+
+import (
+       "syscall"
+       "unsafe"
+
+       "golang.org/x/sys/windows"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const (
+       errnoERROR_IO_PENDING = 997
+)
+
+var (
+       errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+)
+
+// errnoErr returns common boxed Errno values, to prevent
+// allocations at runtime.
+func errnoErr(e syscall.Errno) error {
+       switch e {
+       case 0:
+               return nil
+       case errnoERROR_IO_PENDING:
+               return errERROR_IO_PENDING
+       }
+       // TODO: add more here, after collecting data on the common
+       // error values see on Windows. (perhaps when running
+       // all.bat?)
+       return e
+}
+
+var (
+       modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
+
+       procGetOverlappedResult = modkernel32.NewProc("GetOverlappedResult")
+)
+
+func getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) {
+       var _p0 uint32
+       if wait {
+               _p0 = 1
+       } else {
+               _p0 = 0
+       }
+       r1, _, e1 := syscall.Syscall6(procGetOverlappedResult.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(done)), uintptr(_p0), 0, 0)
+       if r1 == 0 {
+               if e1 != 0 {
+                       err = errnoErr(e1)
+               } else {
+                       err = syscall.EINVAL
+               }
+       }
+       return
+}