]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: ring management moved to wintun.dll
authorSimon Rozman <simon@rozman.si>
Sat, 24 Oct 2020 20:40:46 +0000 (22:40 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sat, 7 Nov 2020 14:20:49 +0000 (15:20 +0100)
Signed-off-by: Simon Rozman <simon@rozman.si>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/tun_windows.go
tun/wintun/ring_windows.go [deleted file]
tun/wintun/session_windows.go [new file with mode: 0644]
tun/wintun/wintun_windows.go

index 684d6f0be20610dbcdcd4778adb7580c26a0c878..63eb812d787105d91582bde2321befa335142106 100644 (file)
@@ -9,10 +9,9 @@ import (
        "errors"
        "fmt"
        "os"
-       "sync"
        "sync/atomic"
        "time"
-       "unsafe"
+       "unsafe"
 
        "golang.org/x/sys/windows"
 
@@ -40,8 +39,8 @@ type NativeTun struct {
        errors    chan error
        forcedMTU int
        rate      rateJuggler
-       rings     *wintun.RingDescriptor
-       writeLock sync.Mutex
+       session   wintun.Session
+       readWait  windows.Handle
 }
 
 var WintunPool *wintun.Pool
@@ -103,17 +102,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
                forcedMTU: forcedMTU,
        }
 
-       tun.rings, err = wintun.NewRingDescriptor()
+       tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
        if err != nil {
-               tun.Close()
-               return nil, fmt.Errorf("Error creating events: %v", err)
-       }
-
-       tun.handle, err = tun.wt.Register(tun.rings)
-       if err != nil {
-               tun.Close()
-               return nil, fmt.Errorf("Error registering rings: %v", err)
+               _, err = tun.wt.Delete(false)
+               close(tun.events)
+               return nil, fmt.Errorf("Error starting session: %v", err)
        }
+       tun.readWait = tun.session.ReadWaitEvent()
        return tun, nil
 }
 
@@ -131,13 +126,7 @@ func (tun *NativeTun) Events() chan Event {
 
 func (tun *NativeTun) Close() error {
        tun.close = true
-       if tun.rings.Send.TailMoved != 0 {
-               windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
-       }
-       if tun.handle != windows.InvalidHandle {
-               windows.CloseHandle(tun.handle)
-       }
-       tun.rings.Close()
+       tun.session.End()
        var err error
        if tun.wt != nil {
                _, err = tun.wt.Delete(false)
@@ -164,56 +153,34 @@ retry:
                return 0, err
        default:
        }
-       if tun.close {
-               return 0, os.ErrClosed
-       }
-
-       buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
-       if buffHead >= wintun.PacketCapacity {
-               return 0, os.ErrClosed
-       }
-
        start := nanotime()
        shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
-       var buffTail uint32
        for {
-               buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
-               if buffHead != buffTail {
-                       break
-               }
                if tun.close {
                        return 0, os.ErrClosed
                }
-               if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
-                       windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
-                       goto retry
+               packet, err := tun.session.ReceivePacket()
+               switch err {
+               case nil:
+                       packetSize := len(packet)
+                       copy(buff[offset:], packet)
+                       tun.session.ReleaseReceivePacket(packet)
+                       tun.rate.update(uint64(packetSize))
+                       return packetSize, nil
+               case windows.ERROR_NO_MORE_ITEMS:
+                       if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
+                               windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
+                               goto retry
+                       }
+                       procyield(1)
+                       continue
+               case windows.ERROR_HANDLE_EOF:
+                       return 0, os.ErrClosed
+               case windows.ERROR_INVALID_DATA:
+                       return 0, errors.New("Send ring corrupt")
                }
-               procyield(1)
-       }
-       if buffTail >= wintun.PacketCapacity {
-               return 0, os.ErrClosed
-       }
-
-       buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
-       if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
-               return 0, errors.New("incomplete packet header in send ring")
-       }
-
-       packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
-       if packet.Size > wintun.PacketSizeMax {
-               return 0, errors.New("packet too big in send ring")
-       }
-
-       alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
-       if alignedPacketSize > buffContent {
-               return 0, errors.New("incomplete packet in send ring")
+               return 0, fmt.Errorf("Read failed: %v", err)
        }
-
-       copy(buff[offset:], packet.Data[:packet.Size])
-       buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
-       atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
-       tun.rate.update(uint64(packet.Size))
-       return int(packet.Size), nil
 }
 
 func (tun *NativeTun) Flush() error {
@@ -225,36 +192,22 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
                return 0, os.ErrClosed
        }
 
-       packetSize := uint32(len(buff) - offset)
+       packetSize := len(buff) - offset
        tun.rate.update(uint64(packetSize))
-       alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
 
-       tun.writeLock.Lock()
-       defer tun.writeLock.Unlock()
-
-       buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
-       if buffHead >= wintun.PacketCapacity {
-               return 0, os.ErrClosed
+       packet, err := tun.session.AllocateSendPacket(packetSize)
+       if err == nil {
+               copy(packet, buff[offset:])
+               tun.session.SendPacket(packet)
+               return packetSize, nil
        }
-
-       buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
-       if buffTail >= wintun.PacketCapacity {
+       switch err {
+       case windows.ERROR_HANDLE_EOF:
                return 0, os.ErrClosed
-       }
-
-       buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
-       if alignedPacketSize > buffSpace {
+       case windows.ERROR_BUFFER_OVERFLOW:
                return 0, nil // Dropping when ring is full.
        }
-
-       packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
-       packet.Size = packetSize
-       copy(packet.Data[:packetSize], buff[offset:])
-       atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
-       if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
-               windows.SetEvent(tun.rings.Receive.TailMoved)
-       }
-       return int(packetSize), nil
+       return 0, fmt.Errorf("Write failed: %v", err)
 }
 
 // LUID returns Windows interface instance ID.
diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go
deleted file mode 100644 (file)
index ed460fb..0000000
+++ /dev/null
@@ -1,117 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
-       "runtime"
-       "unsafe"
-
-       "golang.org/x/sys/windows"
-)
-
-const (
-       PacketAlignment    = 4        // Number of bytes packets are aligned to in rings
-       PacketSizeMax      = 0xffff   // Maximum packet size
-       PacketCapacity     = 0x800000 // Ring capacity, 8MiB
-       PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
-       ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
-)
-
-type PacketHeader struct {
-       Size uint32
-}
-
-type Packet struct {
-       PacketHeader
-       Data [PacketSizeMax]byte
-}
-
-type Ring struct {
-       Head      uint32
-       Tail      uint32
-       Alertable int32
-       Data      [PacketCapacity + PacketTrailingSize]byte
-}
-
-type RingDescriptor struct {
-       Send, Receive struct {
-               Size      uint32
-               Ring      *Ring
-               TailMoved windows.Handle
-       }
-}
-
-// Wrap returns value modulo ring capacity
-func (rb *Ring) Wrap(value uint32) uint32 {
-       return value & (PacketCapacity - 1)
-}
-
-// Aligns a packet size to PacketAlignment
-func PacketAlign(size uint32) uint32 {
-       return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
-}
-
-func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
-       descriptor = new(RingDescriptor)
-       allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
-       if err != nil {
-               return
-       }
-       defer func() {
-               if err != nil {
-                       descriptor.free()
-                       descriptor = nil
-               }
-       }()
-       descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
-       descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
-       descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
-       if err != nil {
-               return
-       }
-
-       descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
-       descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
-       descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
-       if err != nil {
-               windows.CloseHandle(descriptor.Send.TailMoved)
-               return
-       }
-       runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
-       return
-}
-
-func (descriptor *RingDescriptor) free() {
-       if descriptor.Send.Ring != nil {
-               windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
-               descriptor.Send.Ring = nil
-               descriptor.Receive.Ring = nil
-       }
-}
-
-func (descriptor *RingDescriptor) Close() {
-       if descriptor.Send.TailMoved != 0 {
-               windows.CloseHandle(descriptor.Send.TailMoved)
-               descriptor.Send.TailMoved = 0
-       }
-       if descriptor.Send.TailMoved != 0 {
-               windows.CloseHandle(descriptor.Receive.TailMoved)
-               descriptor.Receive.TailMoved = 0
-       }
-}
-
-func (wintun *Adapter) Register(descriptor *RingDescriptor) (windows.Handle, error) {
-       handle, err := wintun.OpenAdapterDeviceObject()
-       if err != nil {
-               return 0, err
-       }
-       var bytesReturned uint32
-       err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
-       if err != nil {
-               return 0, err
-       }
-       return handle, nil
-}
diff --git a/tun/wintun/session_windows.go b/tun/wintun/session_windows.go
new file mode 100644 (file)
index 0000000..1619e5a
--- /dev/null
@@ -0,0 +1,108 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package wintun
+
+import (
+       "syscall"
+       "unsafe"
+
+       "golang.org/x/sys/windows"
+)
+
+type Session struct {
+       handle uintptr
+}
+
+const (
+       PacketSizeMax   = 0xffff    // Maximum packet size
+       RingCapacityMin = 0x20000   // Minimum ring capacity (128 kiB)
+       RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB)
+)
+
+// Packet with data
+type Packet struct {
+       Next *Packet              // Pointer to next packet in queue
+       Size uint32               // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE)
+       Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet
+}
+
+var (
+       procWintunAllocateSendPacket   = modwintun.NewProc("WintunAllocateSendPacket").Addr()
+       procWintunEndSession           = modwintun.NewProc("WintunEndSession")
+       procWintunGetReadWaitEvent     = modwintun.NewProc("WintunGetReadWaitEvent")
+       procWintunReceivePacket        = modwintun.NewProc("WintunReceivePacket").Addr()
+       procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket").Addr()
+       procWintunSendPacket           = modwintun.NewProc("WintunSendPacket").Addr()
+       procWintunStartSession         = modwintun.NewProc("WintunStartSession")
+)
+
+func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) {
+       r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0)
+       if r0 == 0 {
+               err = e1
+       } else {
+               session = Session{r0}
+       }
+       return
+}
+
+func (session Session) End() {
+       syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0)
+       session.handle = 0
+}
+
+func (session Session) ReadWaitEvent() (handle windows.Handle) {
+       r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0)
+       handle = windows.Handle(r0)
+       return
+}
+
+func (session Session) ReceivePacket() (packet []byte, err error) {
+       var packetSize uint32
+       r0, _, e1 := syscall.Syscall(procWintunReceivePacket, 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
+       if r0 == 0 {
+               err = e1
+               return
+       }
+       unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
+       return
+}
+
+func (session Session) ReleaseReceivePacket(packet []byte) {
+       syscall.Syscall(procWintunReleaseReceivePacket, 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
+}
+
+func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
+       r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket, 2, session.handle, uintptr(packetSize), 0)
+       if r0 == 0 {
+               err = e1
+               return
+       }
+       unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
+       return
+}
+
+func (session Session) SendPacket(packet []byte) {
+       syscall.Syscall(procWintunSendPacket, 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
+}
+
+// unsafeSlice updates the slice slicePtr to be a slice
+// referencing the provided data with its length & capacity set to
+// lenCap.
+//
+// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
+// update callers to use unsafe.Slice instead of this.
+func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
+       type sliceHeader struct {
+               Data unsafe.Pointer
+               Len  int
+               Cap  int
+       }
+       h := (*sliceHeader)(slicePtr)
+       h.Data = data
+       h.Len = lenCap
+       h.Cap = lenCap
+}
index ac3357992063cd613ecb95db1447f811e51a8867..e7ba8b6e2d282c140ca6521f340bfebc88f912a6 100644 (file)
@@ -45,7 +45,6 @@ var (
        procWintunGetAdapterLUID          = modwintun.NewProc("WintunGetAdapterLUID")
        procWintunGetAdapterName          = modwintun.NewProc("WintunGetAdapterName")
        procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
-       procWintunOpenAdapterDeviceObject = modwintun.NewProc("WintunOpenAdapterDeviceObject")
        procWintunSetAdapterName          = modwintun.NewProc("WintunSetAdapterName")
        procWintunSetLogger               = modwintun.NewProc("WintunSetLogger")
 )
@@ -210,16 +209,6 @@ func RunningVersion() (version uint32, err error) {
        return
 }
 
-// handle returns a handle to the adapter device object. Release handle with windows.CloseHandle
-func (wintun *Adapter) OpenAdapterDeviceObject() (handle windows.Handle, err error) {
-       r0, _, e1 := syscall.Syscall(procWintunOpenAdapterDeviceObject.Addr(), 1, uintptr(wintun.handle), 0, 0)
-       handle = windows.Handle(r0)
-       if handle == windows.InvalidHandle {
-               err = e1
-       }
-       return
-}
-
 // LUID returns the LUID of the adapter.
 func (wintun *Adapter) LUID() (luid uint64) {
        syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)