]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wintun: manage ring memory manually
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 21 Nov 2019 13:48:21 +0000 (14:48 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 22 Nov 2019 12:13:55 +0000 (13:13 +0100)
It's large and Go's garbage collector doesn't deal with it especially
well.

tun/tun_windows.go
tun/wintun/ring_windows.go

index daad4aa5bf9aa3d5652f9fbb6a3d0934c224aea8..8fc51742c74b53be317bd39eef9f58ad3bcc03fb 100644 (file)
@@ -35,11 +35,11 @@ type NativeTun struct {
        wt        *wintun.Interface
        handle    windows.Handle
        close     bool
-       rings     wintun.RingDescriptor
        events    chan Event
        errors    chan error
        forcedMTU int
        rate      rateJuggler
+       rings     *wintun.RingDescriptor
 }
 
 const WintunPool = wintun.Pool("WireGuard")
@@ -93,13 +93,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
                forcedMTU: forcedMTU,
        }
 
-       err = tun.rings.Init()
+       tun.rings, err = wintun.NewRingDescriptor()
        if err != nil {
                tun.Close()
                return nil, fmt.Errorf("Error creating events: %v", err)
        }
 
-       tun.handle, err = tun.wt.Register(&tun.rings)
+       tun.handle, err = tun.wt.Register(tun.rings)
        if err != nil {
                tun.Close()
                return nil, fmt.Errorf("Error registering rings: %v", err)
index 8f46bc979ec8af493a2857dc1c3456d7254ca581..8e6b37584d4e8a997a1e2cf65d341df18675a662 100644 (file)
@@ -6,6 +6,7 @@
 package wintun
 
 import (
+       "runtime"
        "unsafe"
 
        "golang.org/x/sys/windows"
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
        return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
 }
 
-func (descriptor *RingDescriptor) Init() (err error) {
+func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
+       descriptor = new(RingDescriptor)
+       allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+       if err != nil {
+               return
+       }
+       defer func() {
+               if err != nil {
+                       descriptor.free()
+                       descriptor = nil
+               }
+       }()
        descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
-       descriptor.Send.Ring = &Ring{}
+       descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
        descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
        if err != nil {
                return
        }
 
        descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
-       descriptor.Receive.Ring = &Ring{}
+       descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
        descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
        if err != nil {
                windows.CloseHandle(descriptor.Send.TailMoved)
                return
        }
-
+       runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
        return
 }
 
+func (descriptor *RingDescriptor) free() {
+       if descriptor.Send.Ring != nil {
+               windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
+               descriptor.Send.Ring = nil
+               descriptor.Receive.Ring = nil
+       }
+}
+
 func (descriptor *RingDescriptor) Close() {
        if descriptor.Send.TailMoved != 0 {
                windows.CloseHandle(descriptor.Send.TailMoved)