]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Moved TUN device creation to pre-fork
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Tue, 14 Nov 2017 17:26:28 +0000 (18:26 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Tue, 14 Nov 2017 17:26:28 +0000 (18:26 +0100)
src/daemon_linux.go
src/device.go
src/main.go
src/tests/netns.sh
src/tun.go
src/tun_linux.go

index 730f89efa5d36689a2e53fad6e49956369a4cb0b..8210f8b8b281fd04fdc5d809482faa356602fc99 100644 (file)
@@ -11,18 +11,9 @@ import (
  * TODO: Use env variable to spawn in background
  */
 
-func Daemonize() error {
+func Daemonize(attr *os.ProcAttr) error {
        argv := []string{os.Args[0], "--foreground"}
        argv = append(argv, os.Args[1:]...)
-       attr := &os.ProcAttr{
-               Dir: ".",
-               Env: os.Environ(),
-               Files: []*os.File{
-                       os.Stdin,
-                       nil,
-                       nil,
-               },
-       }
        process, err := os.StartProcess(
                argv[0],
                argv,
index 9422d4959d743e310a7e3236629c43c3fad523af..429ee463d6effd2fc1d1a1a00023d85bddc56df6 100644 (file)
@@ -126,13 +126,13 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
        device.pool.messageBuffers.Put(msg)
 }
 
-func NewDevice(tun TUNDevice, logLevel int) *Device {
+func NewDevice(tun TUNDevice, logger *Logger) *Device {
        device := new(Device)
 
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
-       device.log = NewLogger(logLevel, "("+tun.Name()+") ")
+       device.log = logger
        device.peers = make(map[NoisePublicKey]*Peer)
        device.tun.device = tun
 
index eb3c67f10b1d10a96f1b16a5ba399760ee0c0843..3808c9cb9b4ccc8f4e86e4df215d474edc5c0b32 100644 (file)
@@ -2,10 +2,14 @@ package main
 
 import (
        "fmt"
-       "log"
        "os"
        "os/signal"
        "runtime"
+       "strconv"
+)
+
+const (
+       EnvWGTunFD = "WG_TUN_FD"
 )
 
 func printUsage() {
@@ -43,28 +47,6 @@ func main() {
                interfaceName = os.Args[1]
        }
 
-       // daemonize the process
-
-       if !foreground {
-               err := Daemonize()
-               if err != nil {
-                       log.Println("Failed to daemonize:", err)
-               }
-               return
-       }
-
-       // increase number of go workers (for Go <1.5)
-
-       runtime.GOMAXPROCS(runtime.NumCPU())
-
-       // open TUN device
-
-       tun, err := CreateTUN(interfaceName)
-       if err != nil {
-               log.Println("Failed to create tun device:", err)
-               return
-       }
-
        // get log level (default: info)
 
        logLevel := func() int {
@@ -79,22 +61,76 @@ func main() {
                return LogLevelInfo
        }()
 
-       // create wireguard device
+       logger := NewLogger(
+               logLevel,
+               fmt.Sprintf("(%s) ", interfaceName),
+       )
+       logger.Debug.Println("Debug log enabled")
 
-       device := NewDevice(tun, logLevel)
+       // open TUN device
 
-       logInfo := device.log.Info
-       logError := device.log.Error
-       logDebug := device.log.Debug
+       tun, err := func() (TUNDevice, error) {
+               tunFdStr := os.Getenv(EnvWGTunFD)
+               if tunFdStr == "" {
+                       return CreateTUN(interfaceName)
+               }
 
-       logInfo.Println("Device started")
-       logDebug.Println("Debug log enabled")
+               // construct tun device from supplied FD
+
+               fd, err := strconv.ParseUint(tunFdStr, 10, 32)
+               if err != nil {
+                       return nil, err
+               }
+
+               file := os.NewFile(uintptr(fd), "/dev/net/tun")
+               return CreateTUNFromFile(interfaceName, file)
+       }()
+
+       if err != nil {
+               logger.Error.Println("Failed to create TUN device:", err)
+       }
+
+       // daemonize the process
+
+       if !foreground {
+               env := os.Environ()
+               _, ok := os.LookupEnv(EnvWGTunFD)
+               if !ok {
+                       kvp := fmt.Sprintf("%s=3", EnvWGTunFD)
+                       env = append(env, kvp)
+               }
+               attr := &os.ProcAttr{
+                       Files: []*os.File{
+                               nil, // stdin
+                               nil, // stdout
+                               nil, // stderr
+                               tun.File(),
+                       },
+                       Dir: ".",
+                       Env: env,
+               }
+               err = Daemonize(attr)
+               if err != nil {
+                       logger.Error.Println("Failed to daemonize:", err)
+               }
+               return
+       }
+
+       // increase number of go workers (for Go <1.5)
+
+       runtime.GOMAXPROCS(runtime.NumCPU())
+
+       // create wireguard device
+
+       device := NewDevice(tun, logger)
+       logger.Info.Println("Device started")
 
        // start configuration lister
 
        uapi, err := NewUAPIListener(interfaceName)
        if err != nil {
-               logError.Fatal("UAPI listen error:", err)
+               logger.Error.Println("UAPI listen error:", err)
+               return
        }
 
        errs := make(chan error)
@@ -112,7 +148,7 @@ func main() {
                }
        }()
 
-       logInfo.Println("UAPI listener started")
+       logger.Info.Println("UAPI listener started")
 
        // wait for program to terminate
 
@@ -129,5 +165,5 @@ func main() {
 
        uapi.Close()
 
-       logInfo.Println("Closing")
+       logger.Info.Println("Shutting down")
 }
index 9124b80dc1f5c5269b36b2b2e8090ac67df69d2e..b5c2f9c28b17f000bdc8a403c51fe820437bdb08 100755 (executable)
@@ -28,7 +28,7 @@ netns0="wg-test-$$-0"
 netns1="wg-test-$$-1"
 netns2="wg-test-$$-2"
 program="../wireguard-go"
-export LOG_LEVEL="debug"
+export LOG_LEVEL="info"
 
 pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
 pp() { pretty "" "$*"; "$@"; }
@@ -72,13 +72,11 @@ pp ip netns add $netns2
 ip0 link set up dev lo
 
 # ip0 link add dev wg1 type wireguard
-n0 $program -f wg1 &
-sleep 1
+n0 $program wg1
 ip0 link set wg1 netns $netns1
 
 # ip0 link add dev wg1 type wireguard
-n0 $program -f wg2 &
-sleep 1
+n0 $program wg2
 ip0 link set wg2 netns $netns2
 
 key1="$(pp wg genkey)"
@@ -147,8 +145,6 @@ tests() {
     n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2
 }
 
-echo "4"
-
 [[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}"
 big_mtu=$(( 34816 - 1500 + $orig_mtu ))
 
@@ -234,9 +230,8 @@ ip2 link del wg2
 # ip1 link add dev wg1 type wireguard
 # ip2 link add dev wg1 type wireguard
 
-n1 $program -f wg1 &
-n2 $program -f wg2 &
-sleep 5
+n1 $program wg1
+n2 $program wg2
 
 configure_peers
 
@@ -291,9 +286,8 @@ ip2 link del wg2
 
 # ip1 link add dev wg1 type wireguard
 # ip2 link add dev wg1 type wireguard
-n1 $program -f wg1 &
-n2 $program -f wg2 &
-sleep 5
+n1 $program wg1
+n2 $program wg2
 
 configure_peers
 
@@ -354,4 +348,5 @@ n2 ping -W 1 -c 1 192.168.241.1
 ip1 link del veth1
 ip1 link del wg1
 ip2 link del wg2
+
 echo "done"
index 9eed98747829e7fe87924afdc6fd0285821052c2..5bdac0ed38520d6a3b3ad4e67a4fe9e7950b6897 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "os"
        "sync/atomic"
 )
 
@@ -15,6 +16,7 @@ const (
 )
 
 type TUNDevice interface {
+       File() *os.File            // returns the file descriptor of the device
        Read([]byte) (int, error)  // read a packet from the device (without any additional headers)
        Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
        MTU() (int, error)         // returns the MTU of the device
index accc6c6064447910e291d20e281606569c87d9e5..ce6304c48aafaee73fb213b385686c8450ea7254 100644 (file)
@@ -56,6 +56,11 @@ type NativeTun struct {
        events chan TUNEvent // device related events
 }
 
+func (tun *NativeTun) File() *os.File {
+       println(tun.fd.Name())
+       return tun.fd
+}
+
 func (tun *NativeTun) RoutineNetlinkListener() {
        sock := int(C.bind_rtmgrp())
        if sock < 0 {
@@ -248,6 +253,29 @@ func (tun *NativeTun) Close() error {
        return nil
 }
 
+func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
+       device := &NativeTun{
+               fd:     fd,
+               name:   name,
+               events: make(chan TUNEvent, 5),
+               errors: make(chan error, 5),
+       }
+
+       // start event listener
+
+       var err error
+       device.index, err = getIFIndex(device.name)
+       if err != nil {
+               return nil, err
+       }
+
+       go device.RoutineNetlinkListener()
+
+       // set default MTU
+
+       return device, device.setMTU(DefaultMTU)
+}
+
 func CreateTUN(name string) (TUNDevice, error) {
 
        // open clone device