]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: windows: protect reads from closing
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 27 Apr 2021 02:22:45 +0000 (22:22 -0400)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 27 Apr 2021 02:22:45 +0000 (22:22 -0400)
The code previously used the old errors channel for checking, rather
than the simpler boolean, which caused issues on shutdown, since the
errors channel was meaningless. However, looking at this exposed a more
basic problem: Close() and all the other functions that check the closed
boolean can race. So protect with a basic RW lock, to ensure that
Close() waits for all pending operations to complete.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/tun_windows.go

index 9d83db738789194d28051d3608bcf95803a9f226..c8b8d39cb2445d78658ea7993b10e6910872701e 100644 (file)
@@ -37,8 +37,8 @@ type NativeTun struct {
        wt        *wintun.Adapter
        handle    windows.Handle
        close     bool
+       closing   sync.RWMutex
        events    chan Event
-       errors    chan error
        forcedMTU int
        rate      rateJuggler
        session   wintun.Session
@@ -97,7 +97,6 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
                wt:        wt,
                handle:    windows.InvalidHandle,
                events:    make(chan Event, 10),
-               errors:    make(chan error, 1),
                forcedMTU: forcedMTU,
        }
 
@@ -112,6 +111,11 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
 }
 
 func (tun *NativeTun) Name() (string, error) {
+       tun.closing.RLock()
+       defer tun.closing.RUnlock()
+       if tun.close {
+               return "", os.ErrClosed
+       }
        return tun.wt.Name()
 }
 
@@ -126,6 +130,8 @@ func (tun *NativeTun) Events() chan Event {
 func (tun *NativeTun) Close() error {
        var err error
        tun.closeOnce.Do(func() {
+               tun.closing.Lock()
+               defer tun.closing.Unlock()
                tun.close = true
                tun.session.End()
                if tun.wt != nil {
@@ -148,11 +154,11 @@ func (tun *NativeTun) ForceMTU(mtu int) {
 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 
 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+       tun.closing.RLock()
+       defer tun.closing.RUnlock()
 retry:
-       select {
-       case err := <-tun.errors:
-               return 0, err
-       default:
+       if tun.close {
+               return 0, os.ErrClosed
        }
        start := nanotime()
        shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
@@ -189,6 +195,8 @@ func (tun *NativeTun) Flush() error {
 }
 
 func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
+       tun.closing.RLock()
+       defer tun.closing.RUnlock()
        if tun.close {
                return 0, os.ErrClosed
        }
@@ -213,6 +221,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 
 // LUID returns Windows interface instance ID.
 func (tun *NativeTun) LUID() uint64 {
+       tun.closing.RLock()
+       defer tun.closing.RUnlock()
+       if tun.close {
+               return 0
+       }
        return tun.wt.LUID()
 }