]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
wireguard-go-bridge: account for network changes
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 7 Dec 2018 20:47:19 +0000 (21:47 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 7 Dec 2018 20:50:19 +0000 (21:50 +0100)
Everytime the network changes, we need to recreate the UDP socket,
because the ephemeral listen port is tied to the old physical interface.
As well, we need to re-set the IP addresses for each endpoint, so that
they're passed to getaddrinfo and are then resolved using DNS46.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
wireguard-go-bridge/src/api-ios.go

index 67ce7856c39d9aec143375da41490ba32b3ae09a..0fdb3be31394465a70300e5ee91a5813bcaf991f 100644 (file)
@@ -25,6 +25,8 @@ import (
        "os/signal"
        "runtime"
        "strings"
+       "syscall"
+       "time"
        "unsafe"
 )
 
@@ -46,12 +48,54 @@ func (l *CLogger) Write(p []byte) (int, error) {
        return len(p), nil
 }
 
-var tunnelHandles map[int32]*Device
+type DeviceState struct {
+       device            *Device
+       logger            *Logger
+       endpointsTimer    *time.Timer
+       endpointsSettings string
+}
+
+var tunnelHandles map[int32]*DeviceState
+
+func listenForRouteChanges() {
+       //TODO: replace with NWPathMonitor
+       data := make([]byte, os.Getpagesize())
+       routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+       if err != nil {
+               return
+       }
+       for {
+               n, err := unix.Read(routeSocket, data)
+               if err != nil {
+                       if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+                               continue
+                       }
+                       return
+               }
+
+               if n < 4 {
+                       continue
+               }
+               for _, deviceState := range tunnelHandles {
+                       if deviceState.endpointsTimer == nil {
+                               deviceState.endpointsTimer = time.AfterFunc(time.Second, func() {
+                                       deviceState.endpointsTimer = nil
+                                       bufferedSettings := bufio.NewReadWriter(bufio.NewReader(strings.NewReader(deviceState.endpointsSettings)), bufio.NewWriter(ioutil.Discard))
+                                       deviceState.logger.Info.Println("Setting endpoints for re-resolution due to network change")
+                                       err := ipcSetOperation(deviceState.device, bufferedSettings)
+                                       if err != nil {
+                                               deviceState.logger.Error.Println(err)
+                                       }
+                               })
+                       }
+               }
+       }
+}
 
 func init() {
        versionString = C.CString(WireGuardGoVersion)
        roamingDisabled = true
-       tunnelHandles = make(map[int32]*Device)
+       tunnelHandles = make(map[int32]*DeviceState)
        signals := make(chan os.Signal)
        signal.Notify(signals, unix.SIGUSR2)
        go func() {
@@ -67,6 +111,7 @@ func init() {
                        }
                }
        }()
+       go listenForRouteChanges()
 }
 
 //export wgSetLogger
@@ -74,6 +119,32 @@ func wgSetLogger(loggerFn uintptr) {
        loggerFunc = unsafe.Pointer(loggerFn)
 }
 
+func extractEndpointFromSettings(settings string) string {
+       var b strings.Builder
+       pubkey := ""
+       endpoint := ""
+       listenPort := "listen_port=0"
+       for _, line := range strings.Split(settings, "\n") {
+               if strings.HasPrefix(line, "listen_port=") {
+                       listenPort = line
+               } else if strings.HasPrefix(line, "public_key=") {
+                       if pubkey != "" && endpoint != "" {
+                               b.WriteString(pubkey + "\n" + endpoint + "\n")
+                       }
+                       pubkey = line
+               } else if strings.HasPrefix(line, "endpoint=") {
+                       endpoint = line
+               } else if line == "remove=true" {
+                       pubkey = ""
+                       endpoint = ""
+               }
+       }
+       if pubkey != "" && endpoint != "" {
+               b.WriteString(pubkey + "\n" + endpoint + "\n")
+       }
+       return listenPort + "\n" + b.String()
+}
+
 //export wgTurnOn
 func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
        interfaceName := string([]byte(ifnameRef))
@@ -113,18 +184,27 @@ func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
        if i == math.MaxInt32 {
                return -1
        }
-       tunnelHandles[i] = device
+       tunnelHandles[i] = &DeviceState{
+               device:            device,
+               logger:            logger,
+               endpointsSettings: extractEndpointFromSettings(settings),
+       }
        return i
 }
 
 //export wgTurnOff
 func wgTurnOff(tunnelHandle int32) {
-       device, ok := tunnelHandles[tunnelHandle]
+       deviceState, ok := tunnelHandles[tunnelHandle]
        if !ok {
                return
        }
        delete(tunnelHandles, tunnelHandle)
-       device.Close()
+       t := deviceState.endpointsTimer
+       if t != nil {
+               deviceState.endpointsTimer = nil
+               t.Stop()
+       }
+       deviceState.device.Close()
 }
 
 //export wgVersion