]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: make NativeTun.Close well behaved, not crash on double close
authorBrad Fitzpatrick <bradfitz@tailscale.com>
Thu, 18 Feb 2021 22:53:22 +0000 (14:53 -0800)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 22 Feb 2021 14:26:29 +0000 (15:26 +0100)
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
tun/tun_darwin.go
tun/tun_freebsd.go
tun/tun_linux.go
tun/tun_openbsd.go
tun/tun_windows.go

index 542f666dbeb29d595204ecd535e451611904f620..a703c8c34eac044ef483692dd52984aa585576f1 100644 (file)
@@ -10,6 +10,7 @@ import (
        "fmt"
        "net"
        "os"
+       "sync"
        "syscall"
        "time"
        "unsafe"
@@ -26,6 +27,7 @@ type NativeTun struct {
        events      chan Event
        errors      chan error
        routeSocket int
+       closeOnce   sync.Once
 }
 
 func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
@@ -256,14 +258,16 @@ func (tun *NativeTun) Flush() error {
 }
 
 func (tun *NativeTun) Close() error {
-       var err2 error
-       err1 := tun.tunFile.Close()
-       if tun.routeSocket != -1 {
-               unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
-               err2 = unix.Close(tun.routeSocket)
-       } else if tun.events != nil {
-               close(tun.events)
-       }
+       var err1, err2 error
+       tun.closeOnce.Do(func() {
+               err1 = tun.tunFile.Close()
+               if tun.routeSocket != -1 {
+                       unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+                       err2 = unix.Close(tun.routeSocket)
+               } else if tun.events != nil {
+                       close(tun.events)
+               }
+       })
        if err1 != nil {
                return err1
        }
index e0dc2e17a1b75089fc689958731b9347c7048808..12b44da05e18e0d56558064081076d8e4b05d2c0 100644 (file)
@@ -11,6 +11,7 @@ import (
        "fmt"
        "net"
        "os"
+       "sync"
        "syscall"
        "unsafe"
 
@@ -82,6 +83,7 @@ type NativeTun struct {
        events      chan Event
        errors      chan error
        routeSocket int
+       closeOnce   sync.Once
 }
 
 func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -472,16 +474,18 @@ func (tun *NativeTun) Flush() error {
 }
 
 func (tun *NativeTun) Close() error {
-       var err3 error
-       err1 := tun.tunFile.Close()
-       err2 := tunDestroy(tun.name)
-       if tun.routeSocket != -1 {
-               unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
-               err3 = unix.Close(tun.routeSocket)
-               tun.routeSocket = -1
-       } else if tun.events != nil {
-               close(tun.events)
-       }
+       var err1, err2, err3 error
+       tun.closeOnce.Do(func() {
+               err1 = tun.tunFile.Close()
+               err2 = tunDestroy(tun.name)
+               if tun.routeSocket != -1 {
+                       unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+                       err3 = unix.Close(tun.routeSocket)
+                       tun.routeSocket = -1
+               } else if tun.events != nil {
+                       close(tun.events)
+               }
+       })
        if err1 != nil {
                return err1
        }
index 501f3a3108bd71495c488ba6d8fe77835d766ac1..e0c9878c160eabcf31e749d97d3a7296bb6819d1 100644 (file)
@@ -39,6 +39,8 @@ type NativeTun struct {
        hackListenerClosed      sync.Mutex
        statusListenersShutdown chan struct{}
 
+       closeOnce sync.Once
+
        nameOnce  sync.Once // guards calling initNameCache, which sets following fields
        nameCache string    // name of interface
        nameErr   error
@@ -372,17 +374,18 @@ func (tun *NativeTun) Events() chan Event {
 }
 
 func (tun *NativeTun) Close() error {
-       var err1 error
-       if tun.statusListenersShutdown != nil {
-               close(tun.statusListenersShutdown)
-               if tun.netlinkCancel != nil {
-                       err1 = tun.netlinkCancel.Cancel()
+       var err1, err2 error
+       tun.closeOnce.Do(func() {
+               if tun.statusListenersShutdown != nil {
+                       close(tun.statusListenersShutdown)
+                       if tun.netlinkCancel != nil {
+                               err1 = tun.netlinkCancel.Cancel()
+                       }
+               } else if tun.events != nil {
+                       close(tun.events)
                }
-       } else if tun.events != nil {
-               close(tun.events)
-       }
-       err2 := tun.tunFile.Close()
-
+               err2 = tun.tunFile.Close()
+       })
        if err1 != nil {
                return err1
        }
index 8fca1e34537696930e7ff9b7d18f399393c626a4..7ef62f4305a7ec76211a849acbb60d580237932f 100644 (file)
@@ -10,6 +10,7 @@ import (
        "fmt"
        "net"
        "os"
+       "sync"
        "syscall"
        "unsafe"
 
@@ -32,6 +33,7 @@ type NativeTun struct {
        events      chan Event
        errors      chan error
        routeSocket int
+       closeOnce   sync.Once
 }
 
 func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -245,15 +247,17 @@ func (tun *NativeTun) Flush() error {
 }
 
 func (tun *NativeTun) Close() error {
-       var err2 error
-       err1 := tun.tunFile.Close()
-       if tun.routeSocket != -1 {
-               unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
-               err2 = unix.Close(tun.routeSocket)
-               tun.routeSocket = -1
-       } else if tun.events != nil {
-               close(tun.events)
-       }
+       var err1, err2 error
+       tun.closeOnce.Do(func() {
+               err1 = tun.tunFile.Close()
+               if tun.routeSocket != -1 {
+                       unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+                       err2 = unix.Close(tun.routeSocket)
+                       tun.routeSocket = -1
+               } else if tun.events != nil {
+                       close(tun.events)
+               }
+       })
        if err1 != nil {
                return err1
        }
index 081b5e2e8a26cd33d6c34c8ad4c86947e9be965a..9d83db738789194d28051d3608bcf95803a9f226 100644 (file)
@@ -10,6 +10,7 @@ import (
        "fmt"
        "log"
        "os"
+       "sync"
        "sync/atomic"
        "time"
        _ "unsafe"
@@ -42,6 +43,7 @@ type NativeTun struct {
        rate      rateJuggler
        session   wintun.Session
        readWait  windows.Handle
+       closeOnce sync.Once
 }
 
 var WintunPool, _ = wintun.MakePool("WireGuard")
@@ -122,13 +124,15 @@ func (tun *NativeTun) Events() chan Event {
 }
 
 func (tun *NativeTun) Close() error {
-       tun.close = true
-       tun.session.End()
        var err error
-       if tun.wt != nil {
-               _, err = tun.wt.Delete(false)
-       }
-       close(tun.events)
+       tun.closeOnce.Do(func() {
+               tun.close = true
+               tun.session.End()
+               if tun.wt != nil {
+                       _, err = tun.wt.Delete(false)
+               }
+               close(tun.events)
+       })
        return err
 }