]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: move ring constants into module
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 29 Aug 2019 18:47:16 +0000 (12:47 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 29 Aug 2019 19:22:17 +0000 (13:22 -0600)
tun/tun_windows.go
tun/wintun/ring.go [new file with mode: 0644]
tun/wintun/wintun_windows.go

index 1891d2153c658c72df5aca877e6d6d13a966f37f..9c635b54c1983c8a7f33b968b55bb1b25c7f7cb0 100644 (file)
@@ -19,40 +19,11 @@ import (
 )
 
 const (
-       packetAlignment            = 4        // Number of bytes packets are aligned to in rings
-       packetSizeMax              = 0xffff   // Maximum packet size
-       packetCapacity             = 0x800000 // Ring capacity, 8MiB
-       packetTrailingSize         = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
-       ioctlRegisterRings         = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
        rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
        spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
        spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
 )
 
-type packetHeader struct {
-       size uint32
-}
-
-type packet struct {
-       packetHeader
-       data [packetSizeMax]byte
-}
-
-type ring struct {
-       head      uint32
-       tail      uint32
-       alertable int32
-       data      [packetCapacity + packetTrailingSize]byte
-}
-
-type ringDescriptor struct {
-       send, receive struct {
-               size      uint32
-               ring      *ring
-               tailMoved windows.Handle
-       }
-}
-
 type rateJuggler struct {
        current       uint64
        nextByteCount uint64
@@ -64,7 +35,7 @@ type NativeTun struct {
        wt        *wintun.Interface
        handle    windows.Handle
        close     bool
-       rings     ringDescriptor
+       rings     wintun.RingDescriptor
        events    chan Event
        errors    chan error
        forcedMTU int
@@ -79,10 +50,6 @@ func procyield(cycles uint32)
 //go:linkname nanotime runtime.nanotime
 func nanotime() int64
 
-func packetAlign(size uint32) uint32 {
-       return (size + (packetAlignment - 1)) &^ (packetAlignment - 1)
-}
-
 //
 // CreateTUN creates a Wintun interface with the given name. Should a Wintun
 // interface with the same name exist, it is reused.
@@ -127,30 +94,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
                forcedMTU: 1500,
        }
 
-       tun.rings.send.size = uint32(unsafe.Sizeof(ring{}))
-       tun.rings.send.ring = &ring{}
-       tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
+       err = tun.rings.Init()
        if err != nil {
                tun.Close()
-               return nil, fmt.Errorf("Error creating event: %v", err)
+               return nil, fmt.Errorf("Error creating events: %v", err)
        }
 
-       tun.rings.receive.size = uint32(unsafe.Sizeof(ring{}))
-       tun.rings.receive.ring = &ring{}
-       tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
-       if err != nil {
-               tun.Close()
-               return nil, fmt.Errorf("Error creating event: %v", err)
-       }
-
-       tun.handle, err = tun.wt.Handle()
-       if err != nil {
-               tun.Close()
-               return nil, err
-       }
-
-       var bytesReturned uint32
-       err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil)
+       tun.handle, err = tun.wt.Register(&tun.rings)
        if err != nil {
                tun.Close()
                return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -172,18 +122,13 @@ func (tun *NativeTun) Events() chan Event {
 
 func (tun *NativeTun) Close() error {
        tun.close = true
-       if tun.rings.send.tailMoved != 0 {
-               windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping
+       if tun.rings.Send.TailMoved != 0 {
+               windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
        }
        if tun.handle != windows.InvalidHandle {
                windows.CloseHandle(tun.handle)
        }
-       if tun.rings.send.tailMoved != 0 {
-               windows.CloseHandle(tun.rings.send.tailMoved)
-       }
-       if tun.rings.send.tailMoved != 0 {
-               windows.CloseHandle(tun.rings.receive.tailMoved)
-       }
+       tun.rings.Close()
        var err error
        if tun.wt != nil {
                _, err = tun.wt.DeleteInterface()
@@ -214,8 +159,8 @@ retry:
                return 0, os.ErrClosed
        }
 
-       buffHead := atomic.LoadUint32(&tun.rings.send.ring.head)
-       if buffHead >= packetCapacity {
+       buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
+       if buffHead >= wintun.PacketCapacity {
                return 0, os.ErrClosed
        }
 
@@ -223,7 +168,7 @@ retry:
        shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
        var buffTail uint32
        for {
-               buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail)
+               buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
                if buffHead != buffTail {
                        break
                }
@@ -231,35 +176,35 @@ retry:
                        return 0, os.ErrClosed
                }
                if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
-                       windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE)
+                       windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
                        goto retry
                }
                procyield(1)
        }
-       if buffTail >= packetCapacity {
+       if buffTail >= wintun.PacketCapacity {
                return 0, os.ErrClosed
        }
 
-       buffContent := tun.rings.send.ring.wrap(buffTail - buffHead)
-       if buffContent < uint32(unsafe.Sizeof(packetHeader{})) {
+       buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
+       if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
                return 0, errors.New("incomplete packet header in send ring")
        }
 
-       packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead]))
-       if packet.size > packetSizeMax {
+       packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
+       if packet.Size > wintun.PacketSizeMax {
                return 0, errors.New("packet too big in send ring")
        }
 
-       alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size)
+       alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
        if alignedPacketSize > buffContent {
                return 0, errors.New("incomplete packet in send ring")
        }
 
-       copy(buff[offset:], packet.data[:packet.size])
-       buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize)
-       atomic.StoreUint32(&tun.rings.send.ring.head, buffHead)
-       tun.rate.update(uint64(packet.size))
-       return int(packet.size), nil
+       copy(buff[offset:], packet.Data[:packet.Size])
+       buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
+       atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
+       tun.rate.update(uint64(packet.Size))
+       return int(packet.Size), nil
 }
 
 func (tun *NativeTun) Flush() error {
@@ -273,29 +218,29 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 
        packetSize := uint32(len(buff) - offset)
        tun.rate.update(uint64(packetSize))
-       alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize)
+       alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
 
-       buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head)
-       if buffHead >= packetCapacity {
+       buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
+       if buffHead >= wintun.PacketCapacity {
                return 0, os.ErrClosed
        }
 
-       buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail)
-       if buffTail >= packetCapacity {
+       buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
+       if buffTail >= wintun.PacketCapacity {
                return 0, os.ErrClosed
        }
 
-       buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment)
+       buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
        if alignedPacketSize > buffSpace {
                return 0, nil // Dropping when ring is full.
        }
 
-       packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail]))
-       packet.size = packetSize
-       copy(packet.data[:packetSize], buff[offset:])
-       atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize))
-       if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 {
-               windows.SetEvent(tun.rings.receive.tailMoved)
+       packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
+       packet.Size = packetSize
+       copy(packet.Data[:packetSize], buff[offset:])
+       atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
+       if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
+               windows.SetEvent(tun.rings.Receive.TailMoved)
        }
        return int(packetSize), nil
 }
@@ -305,11 +250,6 @@ func (tun *NativeTun) LUID() uint64 {
        return tun.wt.LUID()
 }
 
-// wrap returns value modulo ring capacity
-func (rb *ring) wrap(value uint32) uint32 {
-       return value & (packetCapacity - 1)
-}
-
 func (rate *rateJuggler) update(packetLen uint64) {
        now := nanotime()
        total := atomic.AddUint64(&rate.nextByteCount, packetLen)
diff --git a/tun/wintun/ring.go b/tun/wintun/ring.go
new file mode 100644 (file)
index 0000000..8f46bc9
--- /dev/null
@@ -0,0 +1,97 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wintun
+
+import (
+       "unsafe"
+
+       "golang.org/x/sys/windows"
+)
+
+const (
+       PacketAlignment    = 4        // Number of bytes packets are aligned to in rings
+       PacketSizeMax      = 0xffff   // Maximum packet size
+       PacketCapacity     = 0x800000 // Ring capacity, 8MiB
+       PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
+       ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
+)
+
+type PacketHeader struct {
+       Size uint32
+}
+
+type Packet struct {
+       PacketHeader
+       Data [PacketSizeMax]byte
+}
+
+type Ring struct {
+       Head      uint32
+       Tail      uint32
+       Alertable int32
+       Data      [PacketCapacity + PacketTrailingSize]byte
+}
+
+type RingDescriptor struct {
+       Send, Receive struct {
+               Size      uint32
+               Ring      *Ring
+               TailMoved windows.Handle
+       }
+}
+
+// Wrap returns value modulo ring capacity
+func (rb *Ring) Wrap(value uint32) uint32 {
+       return value & (PacketCapacity - 1)
+}
+
+// Aligns a packet size to PacketAlignment
+func PacketAlign(size uint32) uint32 {
+       return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
+}
+
+func (descriptor *RingDescriptor) Init() (err error) {
+       descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
+       descriptor.Send.Ring = &Ring{}
+       descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
+       if err != nil {
+               return
+       }
+
+       descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
+       descriptor.Receive.Ring = &Ring{}
+       descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
+       if err != nil {
+               windows.CloseHandle(descriptor.Send.TailMoved)
+               return
+       }
+
+       return
+}
+
+func (descriptor *RingDescriptor) Close() {
+       if descriptor.Send.TailMoved != 0 {
+               windows.CloseHandle(descriptor.Send.TailMoved)
+               descriptor.Send.TailMoved = 0
+       }
+       if descriptor.Send.TailMoved != 0 {
+               windows.CloseHandle(descriptor.Receive.TailMoved)
+               descriptor.Receive.TailMoved = 0
+       }
+}
+
+func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) {
+       handle, err := wintun.handle()
+       if err != nil {
+               return 0, err
+       }
+       var bytesReturned uint32
+       err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
+       if err != nil {
+               return 0, err
+       }
+       return handle, nil
+}
index fb8b9088f991f44abfb840a90c4d952fed7e9359..e726748a9b1d2f5278a0287c54ac699c4a866d50 100644 (file)
@@ -698,8 +698,8 @@ func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData,
        return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
 }
 
-// Handle returns a handle to the interface device object.
-func (wintun *Interface) Handle() (windows.Handle, error) {
+// handle returns a handle to the interface device object.
+func (wintun *Interface) handle() (windows.Handle, error) {
        interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
        if err != nil {
                return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err)