]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Netlink sockets can't be shutdown
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 12:08:03 +0000 (14:08 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 12:08:03 +0000 (14:08 +0200)
conn_linux.go
main.go
tun_darwin.go
tun_linux.go

index 8d076ac2d7b26a580a73b964400b902506d4003c..e30631f617cc3053c1f78bf41ec35dd554b7bea3 100644 (file)
@@ -15,6 +15,7 @@
 package main
 
 import (
+       "./rwcancel"
        "errors"
        "golang.org/x/sys/unix"
        "net"
@@ -55,10 +56,11 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
 }
 
 type NativeBind struct {
-       sock4       int
-       sock6       int
-       netlinkSock int
-       lastMark    uint32
+       sock4         int
+       sock6         int
+       netlinkSock   int
+       netlinkCancel *rwcancel.RWCancel
+       lastMark      uint32
 }
 
 var _ Endpoint = (*NativeEndpoint)(nil)
@@ -125,18 +127,23 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
        if err != nil {
                return nil, 0, err
        }
+       bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
+       if err != nil {
+               unix.Close(bind.netlinkSock)
+               return nil, 0, err
+       }
 
        go bind.routineRouteListener(device)
 
        bind.sock6, port, err = create6(port)
        if err != nil {
-               unix.Close(bind.netlinkSock)
+               bind.netlinkCancel.Cancel()
                return nil, port, err
        }
 
        bind.sock4, port, err = create4(port)
        if err != nil {
-               unix.Close(bind.netlinkSock)
+               bind.netlinkCancel.Cancel()
                unix.Close(bind.sock6)
        }
        return &bind, port, err
@@ -178,7 +185,8 @@ func closeUnblock(fd int) error {
 func (bind *NativeBind) Close() error {
        err1 := closeUnblock(bind.sock6)
        err2 := closeUnblock(bind.sock4)
-       err3 := closeUnblock(bind.netlinkSock)
+       err3 := bind.netlinkCancel.Cancel()
+
        if err1 != nil {
                return err1
        }
@@ -539,8 +547,20 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 func (bind *NativeBind) routineRouteListener(device *Device) {
        var reqPeer map[uint32]*Peer
 
+       defer unix.Close(bind.netlinkSock)
+
        for msg := make([]byte, 1<<16); ; {
-               msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+               var err error
+               var msgn int
+               for {
+                       msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+                       if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
+                               break
+                       }
+                       if !bind.netlinkCancel.ReadyRead() {
+                               return
+                       }
+               }
                if err != nil {
                        return
                }
diff --git a/main.go b/main.go
index c9ef343936762dd26e6b9b699e20fde685e79a63..6e876dff0d97998054b29d44be2f41b8b59e8182 100644 (file)
--- a/main.go
+++ b/main.go
@@ -221,14 +221,10 @@ func main() {
                return
        }
 
-       // create wireguard device
-
        device := NewDevice(tun, logger)
 
        logger.Info.Println("Device started")
 
-       // start uapi listener
-
        errs := make(chan error)
        term := make(chan os.Signal)
 
index ac8bffd642e03772f4c3063401055381e700e870..8f9a5d53d42b3de21c40da639cbe3f5e778cee4f 100644 (file)
@@ -122,11 +122,13 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
 
        _, err := tun.Name()
        if err != nil {
+               tun.fd.Close()
                return nil, err
        }
 
        tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd()))
        if err != nil {
+               tun.fd.Close()
                return nil, err
        }
 
index 8e42d4407aaf1843a934e6bfbada1abceff18c04..32bd95d6f8a76a937103aab549ffca425ce4346c 100644 (file)
@@ -31,14 +31,16 @@ const (
 )
 
 type NativeTun struct {
-       fd                      *os.File
-       index                   int32         // if index
-       name                    string        // name of interface
-       errors                  chan error    // async error handling
-       events                  chan TUNEvent // device related events
-       nopi                    bool          // the device was pased IFF_NO_PI
-       rwcancel                *rwcancel.RWCancel
-       netlinkSock             int
+       fd            *os.File
+       fdCancel      *rwcancel.RWCancel
+       index         int32         // if index
+       name          string        // name of interface
+       errors        chan error    // async error handling
+       events        chan TUNEvent // device related events
+       nopi          bool          // the device was pased IFF_NO_PI
+       netlinkSock   int
+       netlinkCancel *rwcancel.RWCancel
+
        statusListenersShutdown chan struct{}
 }
 
@@ -86,9 +88,22 @@ func createNetlinkSocket() (int, error) {
 }
 
 func (tun *NativeTun) RoutineNetlinkListener() {
+       defer unix.Close(tun.netlinkSock)
+
        for msg := make([]byte, 1<<16); ; {
 
-               msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
+               var err error
+               var msgn int
+               for {
+                       msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
+                       if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
+                               break
+                       }
+                       if !tun.netlinkCancel.ReadyRead() {
+                               tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error())
+                               return
+                       }
+               }
                if err != nil {
                        tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
                        return
@@ -323,7 +338,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
                if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
                        return n, err
                }
-               if !tun.rwcancel.ReadyRead() {
+               if !tun.fdCancel.ReadyRead() {
                        return 0, errors.New("tun device closed")
                }
        }
@@ -334,10 +349,13 @@ func (tun *NativeTun) Events() chan TUNEvent {
 }
 
 func (tun *NativeTun) Close() error {
+       var err1 error
        close(tun.statusListenersShutdown)
-       err1 := closeUnblock(tun.netlinkSock)
+       if tun.netlinkCancel != nil {
+               err1 = tun.netlinkCancel.Cancel()
+       }
        err2 := tun.fd.Close()
-       err3 := tun.rwcancel.Cancel()
+       err3 := tun.fdCancel.Cancel()
        close(tun.events)
 
        if err1 != nil {
@@ -404,13 +422,15 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
        }
        var err error
 
-       tun.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
+       tun.fdCancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
        if err != nil {
+               tun.fd.Close()
                return nil, err
        }
 
        _, err = tun.Name()
        if err != nil {
+               tun.fd.Close()
                return nil, err
        }
 
@@ -423,6 +443,12 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
 
        tun.netlinkSock, err = createNetlinkSocket()
        if err != nil {
+               tun.fd.Close()
+               return nil, err
+       }
+       tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
+       if err != nil {
+               tun.fd.Close()
                return nil, err
        }