]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: windows: just open two file handles
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 21 Mar 2019 21:20:09 +0000 (15:20 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 21 Mar 2019 21:20:09 +0000 (15:20 -0600)
tun/mksyscall.go [deleted file]
tun/tun_windows.go
tun/ztun_windows.go [deleted file]

diff --git a/tun/mksyscall.go b/tun/mksyscall.go
deleted file mode 100644 (file)
index 06bb41e..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package tun
-
-//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output ztun_windows.go tun_windows.go
index fffd802f57983b0269a85625436124906300767d..ea244f111a21f24eac5a9f627fd7eb6fbf2c7723 100644 (file)
@@ -9,7 +9,6 @@ import (
        "errors"
        "os"
        "sync"
-       "syscall"
        "time"
        "unsafe"
 
@@ -39,22 +38,18 @@ type exchgBufWrite struct {
 }
 
 type NativeTun struct {
-       wt        *wintun.Wintun
-       tunName   *uint16
-       tunFile   windows.Handle
-       tunLock   sync.Mutex
-       close     bool
-       rdBuff    *exchgBufRead
-       wrBuff    *exchgBufWrite
-       rdEvent   windows.Handle
-       wrEvent   windows.Handle
-       events    chan TUNEvent
-       errors    chan error
-       forcedMtu int
+       wt           *wintun.Wintun
+       tunFileRead  *os.File
+       tunFileWrite *os.File
+       tunLock      sync.Mutex
+       close        bool
+       rdBuff       *exchgBufRead
+       wrBuff       *exchgBufWrite
+       events       chan TUNEvent
+       errors       chan error
+       forcedMtu    int
 }
 
-//sys  getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) = kernel32.GetOverlappedResult
-
 func packetAlign(size uint32) uint32 {
        return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1)
 }
@@ -92,32 +87,10 @@ func CreateTUN(ifname string) (TUNDevice, error) {
                return nil, errors.New("Flushing interface failed: " + err.Error())
        }
 
-       tunNameUTF16, err := windows.UTF16PtrFromString(wt.DataFileName())
-       if err != nil {
-               wt.DeleteInterface(0)
-               return nil, err
-       }
-
-       rde, err := windows.CreateEvent(nil, 1 /*TRUE*/, 0 /*FALSE*/, nil)
-       if err != nil {
-               wt.DeleteInterface(0)
-               return nil, err
-       }
-       wre, err := windows.CreateEvent(nil, 1 /*TRUE*/, 0 /*FALSE*/, nil)
-       if err != nil {
-               windows.CloseHandle(rde)
-               wt.DeleteInterface(0)
-               return nil, err
-       }
-
        return &NativeTun{
                wt:        wt,
-               tunName:   tunNameUTF16,
-               tunFile:   windows.InvalidHandle,
                rdBuff:    &exchgBufRead{},
                wrBuff:    &exchgBufWrite{},
-               rdEvent:   rde,
-               wrEvent:   wre,
                events:    make(chan TUNEvent, 10),
                errors:    make(chan error, 1),
                forcedMtu: 1500,
@@ -126,12 +99,25 @@ func CreateTUN(ifname string) (TUNDevice, error) {
 
 func (tun *NativeTun) openTUN() error {
        retries := retryTimeout * retryRate
-       for {
-               if tun.close {
-                       return errors.New("Cancelled")
-               }
+       if tun.close {
+               return os.ErrClosed
+       }
 
-               file, err := windows.CreateFile(tun.tunName, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED|windows.FILE_FLAG_NO_BUFFERING, 0)
+       var err error
+       name := tun.wt.DataFileName()
+       for tun.tunFileRead == nil {
+               tun.tunFileRead, err = os.OpenFile(name, os.O_RDONLY, 0)
+               if err != nil {
+                       if retries > 0 && !tun.close {
+                               time.Sleep(time.Second / retryRate)
+                               retries--
+                               continue
+                       }
+                       return err
+               }
+       }
+       for tun.tunFileWrite == nil {
+               tun.tunFileWrite, err = os.OpenFile(name, os.O_WRONLY, 0)
                if err != nil {
                        if retries > 0 {
                                time.Sleep(time.Second / retryRate)
@@ -140,47 +126,69 @@ func (tun *NativeTun) openTUN() error {
                        }
                        return err
                }
-
-               tun.tunFile = file
-               return nil
        }
+       return nil
 }
 
 func (tun *NativeTun) closeTUN() (err error) {
-       if tun.tunFile != windows.InvalidHandle {
+       for tun.tunFileRead != nil {
                tun.tunLock.Lock()
-               defer tun.tunLock.Unlock()
-               if tun.tunFile == windows.InvalidHandle {
-                       return
+               if tun.tunFileRead == nil {
+                       tun.tunLock.Unlock()
+                       break
                }
-               t := tun.tunFile
-               tun.tunFile = windows.InvalidHandle
-               err = windows.CloseHandle(t)
+               t := tun.tunFileRead
+               tun.tunFileRead = nil
+               err = t.Close()
+               tun.tunLock.Unlock()
+               break
+       }
+       for tun.tunFileWrite != nil {
+               tun.tunLock.Lock()
+               if tun.tunFileWrite == nil {
+                       tun.tunLock.Unlock()
+                       break
+               }
+               t := tun.tunFileWrite
+               tun.tunFileWrite = nil
+               err2 := t.Close()
+               tun.tunLock.Unlock()
+               if err == nil {
+                       err = err2
+               }
+               break
        }
        return
 }
 
-func (tun *NativeTun) getTUN() (windows.Handle, error) {
-       if tun.tunFile == windows.InvalidHandle {
+func (tun *NativeTun) getTUN() (read *os.File, write *os.File, err error) {
+       read, write = tun.tunFileRead, tun.tunFileWrite
+       if read == nil || write == nil {
+               read, write = nil, nil
                tun.tunLock.Lock()
-               defer tun.tunLock.Unlock()
-               if tun.tunFile != windows.InvalidHandle {
-                       return tun.tunFile, nil
+               if tun.tunFileRead != nil && tun.tunFileWrite != nil {
+                       read, write = tun.tunFileRead, tun.tunFileWrite
+                       tun.tunLock.Unlock()
+                       return
                }
-               err := tun.openTUN()
+               err = tun.closeTUN()
                if err != nil {
-                       return windows.InvalidHandle, err
+                       tun.tunLock.Unlock()
+                       return
                }
+               err = tun.openTUN()
+               if err == nil {
+                       read, write = tun.tunFileRead, tun.tunFileWrite
+               }
+               tun.tunLock.Unlock()
+               return
        }
-       return tun.tunFile, nil
+       return
 }
 
-func (tun *NativeTun) isIOCancelled(err error) bool {
-       // Read&WriteFile() return the same ERROR_OPERATION_ABORTED if we close the handle
-       // or the TUN device is put down. We need a "close" flag to distinguish.
-       en, ok := err.(syscall.Errno)
-       if tun.close && ok && en == windows.ERROR_OPERATION_ABORTED {
-               return true
+func (tun *NativeTun) shouldReopenHandle(err error) bool {
+       if pe, ok := err.(*os.PathError); ok && pe.Err == os.ErrClosed {
+               return !tun.close
        }
        return false
 }
@@ -210,9 +218,6 @@ func (tun *NativeTun) Close() error {
                err1 = err2
        }
 
-       windows.CloseHandle(tun.rdEvent)
-       windows.CloseHandle(tun.wrEvent)
-
        return err1
 }
 
@@ -252,27 +257,20 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
                }
 
                // Get TUN data pipe.
-               file, err := tun.getTUN()
+               file, _, err := tun.getTUN()
                if err != nil {
                        return 0, err
                }
 
                // Fill queue.
-               var n uint32
-               overlapped := &windows.Overlapped{HEvent: tun.rdEvent}
-               err = windows.ReadFile(file, tun.rdBuff.data[:], &n, overlapped)
+               n, err := file.Read(tun.rdBuff.data[:])
                if err != nil {
-                       if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING {
-                               err = getOverlappedResult(file, overlapped, &n, true)
-                       }
-                       if err != nil {
-                               tun.rdBuff.avail = 0
-                               if tun.isIOCancelled(err) {
-                                       return 0, err
-                               }
+                       tun.rdBuff.avail = 0
+                       if tun.shouldReopenHandle(err) {
                                tun.closeTUN()
                                continue
                        }
+                       return 0, err
                }
                tun.rdBuff.offset = 0
                tun.rdBuff.avail = uint32(n)
@@ -287,30 +285,22 @@ func (tun *NativeTun) Flush() error {
        }
 
        // Get TUN data pipe.
-       file, err := tun.getTUN()
+       _, file, err := tun.getTUN()
        if err != nil {
                return err
        }
 
        // Flush write buffer.
-       var n uint32
-       overlapped := &windows.Overlapped{HEvent: tun.wrEvent}
-       err = windows.WriteFile(file, tun.wrBuff.data[:tun.wrBuff.offset], &n, overlapped)
+       _, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset])
        tun.wrBuff.packetNum = 0
        tun.wrBuff.offset = 0
        if err != nil {
-               if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING {
-                       err = getOverlappedResult(file, overlapped, &n, true)
-               }
-               if err != nil {
-                       if tun.isIOCancelled(err) {
-                               return err
-                       }
+               if tun.shouldReopenHandle(err) {
                        tun.closeTUN()
                        return nil
                }
+               return err
        }
-
        return nil
 }
 
diff --git a/tun/ztun_windows.go b/tun/ztun_windows.go
deleted file mode 100644 (file)
index ed779c1..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package tun
-
-import (
-       "syscall"
-       "unsafe"
-
-       "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
-       errnoERROR_IO_PENDING = 997
-)
-
-var (
-       errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
-       switch e {
-       case 0:
-               return nil
-       case errnoERROR_IO_PENDING:
-               return errERROR_IO_PENDING
-       }
-       // TODO: add more here, after collecting data on the common
-       // error values see on Windows. (perhaps when running
-       // all.bat?)
-       return e
-}
-
-var (
-       modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
-
-       procGetOverlappedResult = modkernel32.NewProc("GetOverlappedResult")
-)
-
-func getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) {
-       var _p0 uint32
-       if wait {
-               _p0 = 1
-       } else {
-               _p0 = 0
-       }
-       r1, _, e1 := syscall.Syscall6(procGetOverlappedResult.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(done)), uintptr(_p0), 0, 0)
-       if r1 == 0 {
-               if e1 != 0 {
-                       err = errnoErr(e1)
-               } else {
-                       err = syscall.EINVAL
-               }
-       }
-       return
-}