]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: windows: set event before waiting
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 7 May 2021 07:26:24 +0000 (09:26 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 7 May 2021 07:26:24 +0000 (09:26 +0200)
In 097af6e ("tun: windows: protect reads from closing") we made sure no
functions are running when End() is called, to avoid a UaF. But we still
need to kick that event somehow, so that Read() is allowed to exit, in
order to release the lock. So this commit calls SetEvent, while moving
the closing boolean to be atomic so it can be modified without locks,
and then moves to a WaitGroup for the RCU-like pattern.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/tun_windows.go

index 2305b72f7664ca4a76f93ae0be485e0f786d4cc4..ff16e2f7ca248bd31d4ee528384cb026f3a7ccc5 100644 (file)
@@ -40,10 +40,10 @@ type NativeTun struct {
        session   wintun.Session
        readWait  windows.Handle
        events    chan Event
-       closing   sync.RWMutex
+       running   sync.WaitGroup
        closeOnce sync.Once
+       close     int32
        forcedMTU int
-       close     bool
 }
 
 var WintunPool, _ = wintun.MakePool("WireGuard")
@@ -111,9 +111,9 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
 }
 
 func (tun *NativeTun) Name() (string, error) {
-       tun.closing.RLock()
-       defer tun.closing.RUnlock()
-       if tun.close {
+       tun.running.Add(1)
+       defer tun.running.Done()
+       if atomic.LoadInt32(&tun.close) == 1 {
                return "", os.ErrClosed
        }
        return tun.wt.Name()
@@ -130,9 +130,9 @@ func (tun *NativeTun) Events() chan Event {
 func (tun *NativeTun) Close() error {
        var err error
        tun.closeOnce.Do(func() {
-               tun.closing.Lock()
-               defer tun.closing.Unlock()
-               tun.close = true
+               atomic.StoreInt32(&tun.close, 1)
+               windows.SetEvent(tun.readWait)
+               tun.running.Wait()
                tun.session.End()
                if tun.wt != nil {
                        _, err = tun.wt.Delete(false)
@@ -158,16 +158,16 @@ func (tun *NativeTun) ForceMTU(mtu int) {
 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 
 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
-       tun.closing.RLock()
-       defer tun.closing.RUnlock()
+       tun.running.Add(1)
+       defer tun.running.Done()
 retry:
-       if tun.close {
+       if atomic.LoadInt32(&tun.close) == 1 {
                return 0, os.ErrClosed
        }
        start := nanotime()
        shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
        for {
-               if tun.close {
+               if atomic.LoadInt32(&tun.close) == 1 {
                        return 0, os.ErrClosed
                }
                packet, err := tun.session.ReceivePacket()
@@ -199,9 +199,9 @@ func (tun *NativeTun) Flush() error {
 }
 
 func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-       tun.closing.RLock()
-       defer tun.closing.RUnlock()
-       if tun.close {
+       tun.running.Add(1)
+       defer tun.running.Done()
+       if atomic.LoadInt32(&tun.close) == 1 {
                return 0, os.ErrClosed
        }
 
@@ -225,9 +225,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 
 // LUID returns Windows interface instance ID.
 func (tun *NativeTun) LUID() uint64 {
-       tun.closing.RLock()
-       defer tun.closing.RUnlock()
-       if tun.close {
+       tun.running.Add(1)
+       defer tun.running.Done()
+       if atomic.LoadInt32(&tun.close) == 1 {
                return 0
        }
        return tun.wt.LUID()