]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Conforming to the cross-platform UX
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 15 Jul 2017 11:41:02 +0000 (13:41 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 15 Jul 2017 11:41:02 +0000 (13:41 +0200)
The implementation now terminates when the unix socket is deleted.
Currently we are unable to use fsnotify (on linux),
since it does not notify on the deletion of open files.

The implementation can now daemonize (on linux)
or be kept in the foreground by providing the necessary flag.

src/daemon_linux.go [new file with mode: 0644]
src/main.go
src/uapi_linux.go [new file with mode: 0644]

diff --git a/src/daemon_linux.go b/src/daemon_linux.go
new file mode 100644 (file)
index 0000000..809c176
--- /dev/null
@@ -0,0 +1,34 @@
+package main
+
+import (
+       "os"
+)
+
+/* Daemonizes the process on linux
+ *
+ * This is done by spawning and releasing a copy with the --foreground flag
+ */
+
+func Daemonize() error {
+       argv := []string{os.Args[0], "--foreground"}
+       argv = append(argv, os.Args[1:]...)
+       attr := &os.ProcAttr{
+               Dir: ".",
+               Env: os.Environ(),
+               Files: []*os.File{
+                       os.Stdin,
+                       nil,
+                       nil,
+               },
+       }
+       process, err := os.StartProcess(
+               argv[0],
+               argv,
+               attr,
+       )
+       if err != nil {
+               return err
+       }
+       process.Release()
+       return nil
+}
index dc27472266e56515c458bc65e54a8e1f8f0f75a9..74e7ec97c93f893345f4c0783a5fac546f642cdc 100644 (file)
@@ -1,23 +1,45 @@
 package main
 
 import (
-       "fmt"
        "log"
-       "net"
        "os"
        "runtime"
 )
 
-/* TODO: Fix logging
- * TODO: Fix daemon
- */
-
 func main() {
 
-       if len(os.Args) != 2 {
+       // parse arguments
+
+       var foreground bool
+       var interfaceName string
+       if len(os.Args) < 2 || len(os.Args) > 3 {
+               return
+       }
+
+       switch os.Args[1] {
+       case "-f", "--foreground":
+               foreground = true
+               if len(os.Args) != 3 {
+                       return
+               }
+               interfaceName = os.Args[2]
+       default:
+               foreground = false
+               if len(os.Args) != 2 {
+                       return
+               }
+               interfaceName = os.Args[1]
+       }
+
+       // daemonize the process
+
+       if !foreground {
+               err := Daemonize()
+               if err != nil {
+                       log.Println("Failed to daemonize:", err)
+               }
                return
        }
-       deviceName := os.Args[1]
 
        // increase number of go workers (for Go <1.5)
 
@@ -25,32 +47,33 @@ func main() {
 
        // open TUN device
 
-       tun, err := CreateTUN(deviceName)
+       tun, err := CreateTUN(interfaceName)
        log.Println(tun, err)
        if err != nil {
                return
        }
 
+       // create wireguard device
+
        device := NewDevice(tun, LogLevelDebug)
-       device.log.Info.Println("Starting device")
+
+       logInfo := device.log.Info
+       logError := device.log.Error
+       logInfo.Println("Starting device")
 
        // start configuration lister
 
-       go func() {
-               socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
-               l, err := net.Listen("unix", socketPath)
-               if err != nil {
-                       log.Fatal("listen error:", err)
-               }
+       uapi, err := NewUAPIListener(interfaceName)
+       if err != nil {
+               logError.Fatal("UAPI listen error:", err)
+       }
+       defer uapi.Close()
 
-               for {
-                       conn, err := l.Accept()
-                       if err != nil {
-                               log.Fatal("accept error:", err)
-                       }
-                       go ipcHandle(device, conn)
+       for {
+               conn, err := uapi.Accept()
+               if err != nil {
+                       logError.Fatal("accept error:", err)
                }
-       }()
-
-       device.Wait()
+               go ipcHandle(device, conn)
+       }
 }
diff --git a/src/uapi_linux.go b/src/uapi_linux.go
new file mode 100644 (file)
index 0000000..ee6ee0b
--- /dev/null
@@ -0,0 +1,83 @@
+package main
+
+import (
+       "fmt"
+       "net"
+       "os"
+       "time"
+)
+
+/* TODO:
+ * This code can be improved by using fsnotify once:
+ * https://github.com/fsnotify/fsnotify/pull/205
+ * Is merged
+ */
+
+type UAPIListener struct {
+       listener net.Listener // unix socket listener
+       connNew  chan net.Conn
+       connErr  chan error
+}
+
+func (l *UAPIListener) Accept() (net.Conn, error) {
+       for {
+               select {
+               case conn := <-l.connNew:
+                       return conn, nil
+
+               case err := <-l.connErr:
+                       return nil, err
+               }
+       }
+}
+
+func (l *UAPIListener) Close() error {
+       return l.listener.Close()
+}
+
+func (l *UAPIListener) Addr() net.Addr {
+       return nil
+}
+
+func NewUAPIListener(name string) (net.Listener, error) {
+
+       // open UNIX socket
+
+       socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", name)
+       listener, err := net.Listen("unix", socketPath)
+       if err != nil {
+               return nil, err
+       }
+
+       uapi := &UAPIListener{
+               listener: listener,
+               connNew:  make(chan net.Conn, 1),
+               connErr:  make(chan error, 1),
+       }
+
+       // watch for deletion of socket
+
+       go func(l *UAPIListener) {
+               for ; ; time.Sleep(time.Second) {
+                       if _, err := os.Stat(socketPath); os.IsNotExist(err) {
+                               l.connErr <- err
+                               return
+                       }
+               }
+       }(uapi)
+
+       // watch for new connections
+
+       go func(l *UAPIListener) {
+               for {
+                       conn, err := l.listener.Accept()
+                       if err != nil {
+                               l.connErr <- err
+                               break
+                       }
+                       l.connNew <- conn
+               }
+       }(uapi)
+
+       return uapi, nil
+}