]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Introduce rwcancel
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 22:28:30 +0000 (00:28 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 22:30:52 +0000 (00:30 +0200)
Makefile
misc.go
rwcancel/rwcancel_unix.go [new file with mode: 0644]
tun_linux.go
uapi_linux.go

index 77eaac9f71785698bc718e87de0a7f3bc600c717..1513ef5e7bc6399885448c4c849f9da426ad4b52 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 all: wireguard-go
 
-wireguard-go: $(wildcard *.go)
+wireguard-go: $(wildcard *.go) $(wildcard */*.go)
        go build -o $@
 
 clean:
diff --git a/misc.go b/misc.go
index f94a617fa3845ec9ce8146718faab4895bfb86aa..85a2e803376d923385b14480e5633c97b99675b6 100644 (file)
--- a/misc.go
+++ b/misc.go
@@ -47,7 +47,7 @@ func toInt32(n uint32) int32 {
        return int32(-(n & mask) + (n & ^mask))
 }
 
-func min(a uint, b uint) uint {
+func min(a, b uint) uint {
        if a > b {
                return b
        }
diff --git a/rwcancel/rwcancel_unix.go b/rwcancel/rwcancel_unix.go
new file mode 100644 (file)
index 0000000..cd3661f
--- /dev/null
@@ -0,0 +1,132 @@
+/* SPDX-License-Identifier: GPL-2.0
+ *
+ * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+package rwcancel
+
+import (
+       "errors"
+       "golang.org/x/sys/unix"
+       "os"
+       "runtime"
+       "syscall"
+)
+
+type RWCancel struct {
+       fd            int
+       closingReader *os.File
+       closingWriter *os.File
+}
+
+type fdSet struct {
+       fdset unix.FdSet
+}
+
+func (fdset *fdSet) set(i int) {
+       bits := 32 << (^uint(0) >> 63)
+       fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits)
+}
+
+func (fdset *fdSet) check(i int) bool {
+       bits := 32 << (^uint(0) >> 63)
+       return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
+}
+
+func max(a, b int) int {
+       if a > b {
+               return a
+       }
+       return b
+}
+
+func NewRWCancel(fd int) (*RWCancel, error) {
+       err := unix.SetNonblock(fd, true)
+       if err != nil {
+               return nil, err
+       }
+       rwcancel := RWCancel{fd: fd}
+
+       rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe()
+       if err != nil {
+               return nil, err
+       }
+
+       runtime.SetFinalizer(&rwcancel, func(rw *RWCancel) {
+               rw.Cancel()
+       })
+
+       return &rwcancel, nil
+}
+
+/* https://golang.org/src/crypto/rand/eagain.go */
+func ErrorIsEAGAIN(err error) bool {
+       if pe, ok := err.(*os.PathError); ok {
+               if errno, ok := pe.Err.(syscall.Errno); ok && errno == syscall.EAGAIN {
+                       return true
+               }
+       }
+       if errno, ok := err.(syscall.Errno); ok && errno == syscall.EAGAIN {
+               return true
+       }
+       return false
+}
+
+func (rw *RWCancel) ReadyRead() bool {
+       closeFd := int(rw.closingReader.Fd())
+       fdset := fdSet{}
+       fdset.set(rw.fd)
+       fdset.set(closeFd)
+       _, err := unix.Select(max(rw.fd, closeFd)+1, &fdset.fdset, nil, nil, nil)
+       if err != nil {
+               return false
+       }
+       if fdset.check(closeFd) {
+               return false
+       }
+       return fdset.check(rw.fd)
+}
+
+func (rw *RWCancel) ReadyWrite() bool {
+       closeFd := int(rw.closingReader.Fd())
+       fdset := fdSet{}
+       fdset.set(rw.fd)
+       fdset.set(closeFd)
+       _, err := unix.Select(max(rw.fd, closeFd)+1, nil, &fdset.fdset, nil, nil)
+       if err != nil {
+               return false
+       }
+       if fdset.check(closeFd) {
+               return false
+       }
+       return fdset.check(rw.fd)
+}
+
+func (rw *RWCancel) Read(p []byte) (n int, err error) {
+       for {
+               n, err := unix.Read(rw.fd, p)
+               if err == nil || !ErrorIsEAGAIN(err) {
+                       return n, err
+               }
+               if !rw.ReadyRead() {
+                       return 0, errors.New("fd closed")
+               }
+       }
+}
+
+func (rw *RWCancel) Write(p []byte) (n int, err error) {
+       for {
+               n, err := unix.Write(rw.fd, p)
+               if err == nil || !ErrorIsEAGAIN(err) {
+                       return n, err
+               }
+               if !rw.ReadyWrite() {
+                       return 0, errors.New("fd closed")
+               }
+       }
+}
+
+func (rw *RWCancel) Cancel() (err error) {
+       _, err = rw.closingWriter.Write([]byte{0})
+       return
+}
index 9f60d2b7cf27bcaa4d71929c0ed0b4cd5c5acda3..3510f94695eb3643ff0d0c98c6187690b948d5c8 100644 (file)
@@ -11,6 +11,7 @@ package main
  */
 
 import (
+       "./rwcancel"
        "bytes"
        "encoding/binary"
        "errors"
@@ -20,7 +21,6 @@ import (
        "net"
        "os"
        "strconv"
-       "syscall"
        "time"
        "unsafe"
 )
@@ -31,14 +31,13 @@ const (
 )
 
 type NativeTun struct {
-       fd            *os.File
-       index         int32         // if index
-       name          string        // name of interface
-       errors        chan error    // async error handling
-       events        chan TUNEvent // device related events
-       nopi          bool          // the device was pased IFF_NO_PI
-       closingReader *os.File
-       closingWriter *os.File
+       fd       *os.File
+       index    int32         // if index
+       name     string        // name of interface
+       errors   chan error    // async error handling
+       events   chan TUNEvent // device related events
+       nopi     bool          // the device was pased IFF_NO_PI
+       rwcancel *rwcancel.RWCancel
 }
 
 func (tun *NativeTun) File() *os.File {
@@ -305,43 +304,6 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        return tun.fd.Write(buff)
 }
 
-type FdSet struct {
-       fdset unix.FdSet
-}
-
-func (fdset *FdSet) set(i int) {
-       bits := 32 << (^uint(0) >> 63)
-       fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits)
-}
-
-func (fdset *FdSet) check(i int) bool {
-       bits := 32 << (^uint(0) >> 63)
-       return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
-}
-
-func max(a, b int) int {
-       if a > b {
-               return a
-       }
-       return b
-}
-
-func (tun *NativeTun) readyRead() bool {
-       readFd := int(tun.fd.Fd())
-       closeFd := int(tun.closingReader.Fd())
-       fdset := FdSet{}
-       fdset.set(readFd)
-       fdset.set(closeFd)
-       _, err := unix.Select(max(readFd, closeFd)+1, &fdset.fdset, nil, nil, nil)
-       if err != nil {
-               return false
-       }
-       if fdset.check(closeFd) {
-               return false
-       }
-       return fdset.check(readFd)
-}
-
 func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
        select {
        case err := <-tun.errors:
@@ -360,24 +322,14 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
        }
 }
 
-/* https://golang.org/src/crypto/rand/eagain.go */
-func unixIsEAGAIN(err error) bool {
-       if pe, ok := err.(*os.PathError); ok {
-               if errno, ok := pe.Err.(syscall.Errno); ok && errno == syscall.EAGAIN {
-                       return true
-               }
-       }
-       return false
-}
-
 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
        for {
                n, err := tun.doRead(buff, offset)
-               if err == nil || !unixIsEAGAIN(err) {
+               if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
                        return n, err
                }
-               if !tun.readyRead() {
-                       return 0, errors.New("Tun device closed")
+               if !tun.rwcancel.ReadyRead() {
+                       return 0, errors.New("tun device closed")
                }
        }
 }
@@ -391,7 +343,7 @@ func (tun *NativeTun) Close() error {
        if err != nil {
                return err
        }
-       tun.closingWriter.Write([]byte{0})
+       tun.rwcancel.Cancel()
        close(tun.events)
        return nil
 }
@@ -450,7 +402,7 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
        }
        var err error
 
-       err = unix.SetNonblock(int(fd.Fd()), true)
+       device.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
        if err != nil {
                return nil, err
        }
@@ -460,11 +412,6 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
                return nil, err
        }
 
-       device.closingReader, device.closingWriter, err = os.Pipe()
-       if err != nil {
-               return nil, err
-       }
-
        // start event listener
 
        device.index, err = getIFIndex(device.name)
index c40472e3b922cebb3515e869fb4ffe08eba4c19f..67024e996a558855bcda6da608861b547b07b3ec 100644 (file)
@@ -6,6 +6,7 @@
 package main
 
 import (
+       "./rwcancel"
        "errors"
        "fmt"
        "golang.org/x/sys/unix"
@@ -24,10 +25,11 @@ const (
 )
 
 type UAPIListener struct {
-       listener  net.Listener // unix socket listener
-       connNew   chan net.Conn
-       connErr   chan error
-       inotifyFd int
+       listener        net.Listener // unix socket listener
+       connNew         chan net.Conn
+       connErr         chan error
+       inotifyFd       int
+       inotifyRWCancel *rwcancel.RWCancel
 }
 
 func (l *UAPIListener) Accept() (net.Conn, error) {
@@ -45,10 +47,14 @@ func (l *UAPIListener) Accept() (net.Conn, error) {
 func (l *UAPIListener) Close() error {
        err1 := unix.Close(l.inotifyFd)
        err2 := l.listener.Close()
+       err3 := l.inotifyRWCancel.Cancel()
        if err1 != nil {
                return err1
        }
-       return err2
+       if err2 != nil {
+               return err2
+       }
+       return err3
 }
 
 func (l *UAPIListener) Addr() net.Addr {
@@ -94,15 +100,25 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
                return nil, err
        }
 
+       uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd)
+       if err != nil {
+               unix.Close(uapi.inotifyFd)
+               return nil, err
+       }
+
        go func(l *UAPIListener) {
-               var buff [4096]byte
+               var buff [0]byte
                for {
                        // start with lstat to avoid race condition
                        if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
                                l.connErr <- err
                                return
                        }
-                       unix.Read(uapi.inotifyFd, buff[:])
+                       _, err := uapi.inotifyRWCancel.Read(buff[:])
+                       if err != nil {
+                               l.connErr <- err
+                               return
+                       }
                }
        }(uapi)