]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Allows passing UAPI fd to service
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 17 Nov 2017 13:36:08 +0000 (14:36 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 17 Nov 2017 13:36:08 +0000 (14:36 +0100)
src/main.go
src/tun_linux.go
src/uapi_linux.go

index 3808c9cb9b4ccc8f4e86e4df215d474edc5c0b32..7d86716de66b7d3811127a53221656c2a8d5d23d 100644 (file)
@@ -9,7 +9,8 @@ import (
 )
 
 const (
-       EnvWGTunFD = "WG_TUN_FD"
+       ENV_WG_TUN_FD  = "WG_TUN_FD"
+       ENV_WG_UAPI_FD = "WG_UAPI_FD"
 )
 
 func printUsage() {
@@ -65,46 +66,69 @@ func main() {
                logLevel,
                fmt.Sprintf("(%s) ", interfaceName),
        )
+
        logger.Debug.Println("Debug log enabled")
 
-       // open TUN device
+       // open TUN device (or use supplied fd)
 
        tun, err := func() (TUNDevice, error) {
-               tunFdStr := os.Getenv(EnvWGTunFD)
+               tunFdStr := os.Getenv(ENV_WG_TUN_FD)
                if tunFdStr == "" {
                        return CreateTUN(interfaceName)
                }
 
-               // construct tun device from supplied FD
+               // construct tun device from supplied fd
 
                fd, err := strconv.ParseUint(tunFdStr, 10, 32)
                if err != nil {
                        return nil, err
                }
 
-               file := os.NewFile(uintptr(fd), "/dev/net/tun")
+               file := os.NewFile(uintptr(fd), "")
                return CreateTUNFromFile(interfaceName, file)
        }()
 
        if err != nil {
                logger.Error.Println("Failed to create TUN device:", err)
+               os.Exit(ExitSetupFailed)
        }
 
+       // open UAPI file (or use supplied fd)
+
+       fileUAPI, err := func() (*os.File, error) {
+               uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
+               if uapiFdStr == "" {
+                       return UAPIOpen(interfaceName)
+               }
+
+               // use supplied fd
+
+               fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
+               if err != nil {
+                       return nil, err
+               }
+
+               return os.NewFile(uintptr(fd), ""), nil
+       }()
+
+       if err != nil {
+               logger.Error.Println("UAPI listen error:", err)
+               os.Exit(ExitSetupFailed)
+               return
+       }
        // daemonize the process
 
        if !foreground {
                env := os.Environ()
-               _, ok := os.LookupEnv(EnvWGTunFD)
-               if !ok {
-                       kvp := fmt.Sprintf("%s=3", EnvWGTunFD)
-                       env = append(env, kvp)
-               }
+               env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
+               env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
                attr := &os.ProcAttr{
                        Files: []*os.File{
                                nil, // stdin
                                nil, // stdout
                                nil, // stderr
                                tun.File(),
+                               fileUAPI,
                        },
                        Dir: ".",
                        Env: env,
@@ -112,6 +136,7 @@ func main() {
                err = Daemonize(attr)
                if err != nil {
                        logger.Error.Println("Failed to daemonize:", err)
+                       os.Exit(ExitSetupFailed)
                }
                return
        }
@@ -123,20 +148,17 @@ func main() {
        // create wireguard device
 
        device := NewDevice(tun, logger)
+
        logger.Info.Println("Device started")
 
-       // start configuration lister
-
-       uapi, err := NewUAPIListener(interfaceName)
-       if err != nil {
-               logger.Error.Println("UAPI listen error:", err)
-               return
-       }
+       // start uapi listener
 
        errs := make(chan error)
        term := make(chan os.Signal)
        wait := device.WaitChannel()
 
+       uapi, err := UAPIListen(interfaceName, fileUAPI)
+
        go func() {
                for {
                        conn, err := uapi.Accept()
@@ -161,9 +183,10 @@ func main() {
        case <-errs:
        }
 
-       // clean up UAPI bind
+       // clean up
 
        uapi.Close()
+       device.Close()
 
        logger.Info.Println("Shutting down")
 }
index ce6304c48aafaee73fb213b385686c8450ea7254..2a5b276531b6fe56c8026e8b7fc9afd572cb979e 100644 (file)
@@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) {
 
        val := binary.LittleEndian.Uint32(ifr[16:20])
        if val >= (1 << 31) {
-               return int(val-(1<<31)) - (1 << 31), nil
+               return int(toInt32(val)), nil
        }
        return int(val), nil
 }
index cb9d858f8ff15a7b214c21edf6e7dad14eac6dfd..f97a18a61de87e9250a0fb6a30cc329af5d6847a 100644 (file)
@@ -10,12 +10,12 @@ import (
 )
 
 const (
-       ipcErrorIO         = -int64(unix.EIO)
-       ipcErrorProtocol   = -int64(unix.EPROTO)
-       ipcErrorInvalid    = -int64(unix.EINVAL)
-       ipcErrorPortInUse  = -int64(unix.EADDRINUSE)
-       socketDirectory    = "/var/run/wireguard"
-       socketName         = "%s.sock"
+       ipcErrorIO        = -int64(unix.EIO)
+       ipcErrorProtocol  = -int64(unix.EPROTO)
+       ipcErrorInvalid   = -int64(unix.EINVAL)
+       ipcErrorPortInUse = -int64(unix.EADDRINUSE)
+       socketDirectory   = "/var/run/wireguard"
+       socketName        = "%s.sock"
 )
 
 type UAPIListener struct {
@@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
        return nil
 }
 
-func connectUnixSocket(path string) (net.Listener, error) {
+func UAPIListen(name string, file *os.File) (net.Listener, error) {
 
-       // attempt inital connection
+       // wrap file in listener
 
-       listener, err := net.Listen("unix", path)
-       if err == nil {
-               return listener, nil
-       }
-
-       // check if active
-
-       _, err = net.Dial("unix", path)
-       if err == nil {
-               return nil, errors.New("Unix socket in use")
-       }
-
-       // attempt cleanup
-
-       err = os.Remove(path)
-       if err != nil {
-               return nil, err
-       }
-
-       return net.Listen("unix", path)
-}
-
-func NewUAPIListener(name string) (net.Listener, error) {
-
-       // check if path exist
-
-       err := os.MkdirAll(socketDirectory, 077)
-       if err != nil && !os.IsExist(err) {
-               return nil, err
-       }
-
-       // open UNIX socket
-
-       socketPath := path.Join(
-               socketDirectory,
-               fmt.Sprintf(socketName, name),
-       )
-
-       listener, err := connectUnixSocket(socketPath)
+       listener, err := net.FileListener(file)
        if err != nil {
                return nil, err
        }
@@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
 
        // watch for deletion of socket
 
+       socketPath := path.Join(
+               socketDirectory,
+               fmt.Sprintf(socketName, name),
+       )
+
        uapi.inotifyFd, err = unix.InotifyInit()
        if err != nil {
                return nil, err
@@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
        go func(l *UAPIListener) {
                var buff [4096]byte
                for {
-                       unix.Read(uapi.inotifyFd, buff[:])
+                       // start with lstat to avoid race condition
                        if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
                                l.connErr <- err
                                return
                        }
+                       unix.Read(uapi.inotifyFd, buff[:])
                }
        }(uapi)
 
@@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
 
        return uapi, nil
 }
+
+func UAPIOpen(name string) (*os.File, error) {
+
+       // check if path exist
+
+       err := os.MkdirAll(socketDirectory, 0600)
+       if err != nil && !os.IsExist(err) {
+               return nil, err
+       }
+
+       // open UNIX socket
+
+       socketPath := path.Join(
+               socketDirectory,
+               fmt.Sprintf(socketName, name),
+       )
+
+       addr, err := net.ResolveUnixAddr("unix", socketPath)
+       if err != nil {
+               return nil, err
+       }
+
+       listener, err := func() (*net.UnixListener, error) {
+
+               // initial connection attempt
+
+               listener, err := net.ListenUnix("unix", addr)
+               if err == nil {
+                       return listener, nil
+               }
+
+               // check if socket already active
+
+               _, err = net.Dial("unix", socketPath)
+               if err == nil {
+                       return nil, errors.New("unix socket in use")
+               }
+
+               // cleanup & attempt again
+
+               err = os.Remove(socketPath)
+               if err != nil {
+                       return nil, err
+               }
+               return net.ListenUnix("unix", addr)
+       }()
+
+       if err != nil {
+               return nil, err
+       }
+
+       return listener.File()
+}