]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: linux: work out netpoll trick
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 7 Mar 2019 00:51:41 +0000 (01:51 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 7 Mar 2019 00:51:41 +0000 (01:51 +0100)
tun/tun_linux.go

index c352c1a1e61e0be157d03580ec1250c8848516ca..b7c429c5152fd3a8e40801c55cd863f3cd3a4ef1 100644 (file)
@@ -17,8 +17,8 @@ import (
        "golang.zx2c4.com/wireguard/rwcancel"
        "net"
        "os"
-       "strconv"
        "sync"
+       "syscall"
        "time"
        "unsafe"
 )
@@ -30,8 +30,6 @@ const (
 
 type NativeTun struct {
        tunFile                 *os.File
-       fd                      uintptr
-       fdCancel                *rwcancel.RWCancel
        index                   int32         // if index
        name                    string        // name of interface
        errors                  chan error    // async error handling
@@ -52,9 +50,17 @@ func (tun *NativeTun) routineHackListener() {
        /* This is needed for the detection to work across network namespaces
         * If you are reading this and know a better method, please get in touch.
         */
-       fd := int(tun.fd)
        for {
-               _, err := unix.Write(fd, nil)
+               sysconn, err := tun.tunFile.SyscallConn()
+               if err != nil {
+                       return
+               }
+               err2 := sysconn.Control(func(fd uintptr) {
+                       _, err = unix.Write(int(fd), nil)
+               })
+               if err2 != nil {
+                       return
+               }
                switch err {
                case unix.EINVAL:
                        tun.events <- TUNEventUp
@@ -248,22 +254,32 @@ func (tun *NativeTun) MTU() (int, error) {
                uintptr(unsafe.Pointer(&ifr[0])),
        )
        if errno != 0 {
-               return 0, errors.New("failed to get MTU of TUN device: " + strconv.FormatInt(int64(errno), 10))
+               return 0, errors.New("failed to get MTU of TUN device: " + errno.Error())
        }
 
        return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
 }
 
 func (tun *NativeTun) Name() (string, error) {
+       sysconn, err := tun.tunFile.SyscallConn()
+       if err != nil {
+               return "", err
+       }
        var ifr [ifReqSize]byte
-       _, _, errno := unix.Syscall(
-               unix.SYS_IOCTL,
-               tun.fd,
-               uintptr(unix.TUNGETIFF),
-               uintptr(unsafe.Pointer(&ifr[0])),
-       )
+       var errno syscall.Errno
+       err = sysconn.Control(func(fd uintptr) {
+               _, _, errno = unix.Syscall(
+                       unix.SYS_IOCTL,
+                       fd,
+                       uintptr(unix.TUNGETIFF),
+                       uintptr(unsafe.Pointer(&ifr[0])),
+               )
+       })
+       if err != nil {
+               return "", errors.New("failed to get name of TUN device: " + err.Error())
+       }
        if errno != 0 {
-               return "", errors.New("failed to get name of TUN device: " + strconv.FormatInt(int64(errno), 10))
+               return "", errors.New("failed to get name of TUN device: " + errno.Error())
        }
        nullStr := ifr[:]
        i := bytes.IndexByte(nullStr, 0)
@@ -302,7 +318,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        return tun.tunFile.Write(buff)
 }
 
-func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
        select {
        case err := <-tun.errors:
                return 0, err
@@ -320,18 +336,6 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
        }
 }
 
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
-       for {
-               n, err := tun.doRead(buff, offset)
-               if err == nil || !rwcancel.RetryAfterError(err) {
-                       return n, err
-               }
-               if !tun.fdCancel.ReadyRead() {
-                       return 0, errors.New("tun device closed")
-               }
-       }
-}
-
 func (tun *NativeTun) Events() chan TUNEvent {
        return tun.events
 }
@@ -347,15 +351,11 @@ func (tun *NativeTun) Close() error {
                close(tun.events)
        }
        err2 := tun.tunFile.Close()
-       err3 := tun.fdCancel.Cancel()
 
        if err1 != nil {
                return err1
        }
-       if err2 != nil {
-               return err2
-       }
-       return err3
+       return err2
 }
 
 func CreateTUN(name string, mtu int) (TUNDevice, error) {
@@ -364,13 +364,6 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
                return nil, err
        }
 
-       fd := os.NewFile(uintptr(nfd), cloneDevicePath)
-       if err != nil {
-               return nil, err
-       }
-
-       // create new device
-
        var ifr [ifReqSize]byte
        var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
        nameBytes := []byte(name)
@@ -382,13 +375,21 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
 
        _, _, errno := unix.Syscall(
                unix.SYS_IOCTL,
-               fd.Fd(),
+               uintptr(nfd),
                uintptr(unix.TUNSETIFF),
                uintptr(unsafe.Pointer(&ifr[0])),
        )
        if errno != 0 {
                return nil, errno
        }
+       err = unix.SetNonblock(nfd, true)
+
+       // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
+
+       fd := os.NewFile(uintptr(nfd), cloneDevicePath)
+       if err != nil {
+               return nil, err
+       }
 
        return CreateTUNFromFile(fd, mtu)
 }
@@ -396,7 +397,6 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
 func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
        tun := &NativeTun{
                tunFile:                 file,
-               fd:                      file.Fd(),
                events:                  make(chan TUNEvent, 5),
                errors:                  make(chan error, 5),
                statusListenersShutdown: make(chan struct{}),
@@ -404,11 +404,6 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
        }
        var err error
 
-       tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd))
-       if err != nil {
-               return nil, err
-       }
-
        _, err = tun.Name()
        if err != nil {
                return nil, err
@@ -444,23 +439,20 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
        return tun, nil
 }
 
-func CreateUnmonitoredTUNFromFD(tunFd int) (TUNDevice, string, error) {
-       file := os.NewFile(uintptr(tunFd), "/dev/tun")
+func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) {
+       err := unix.SetNonblock(fd, true)
+       if err != nil {
+               return nil, "", err
+       }
+       file := os.NewFile(uintptr(fd), "/dev/tun")
        tun := &NativeTun{
                tunFile: file,
-               fd:      file.Fd(),
                events:  make(chan TUNEvent, 5),
                errors:  make(chan error, 5),
                nopi:    true,
        }
-       var err error
-       tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd))
-       if err != nil {
-               return nil, "", err
-       }
        name, err := tun.Name()
        if err != nil {
-               tun.fdCancel.Cancel()
                return nil, "", err
        }
        return tun, name, nil