]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
NE: simplify logic
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 21 Dec 2018 14:56:03 +0000 (15:56 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 21 Dec 2018 14:56:03 +0000 (15:56 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
WireGuard/WireGuardNetworkExtension/PacketTunnelSettingsGenerator.swift

index c418ebc92beda0dce38df21500dbd405c55e414f..3a9066d37cc4952edfbe22e0556dda194bd5eab4 100644 (file)
@@ -16,8 +16,9 @@ enum PacketTunnelProviderError: Error {
 class PacketTunnelProvider: NEPacketTunnelProvider {
     
     private var wgHandle: Int32?
-
     private var networkMonitor: NWPathMonitor?
+    private var lastFirstInterface: NWInterface?
+    private var packetTunnelSettingsGenerator: PacketTunnelSettingsGenerator?
 
     deinit {
         networkMonitor?.cancel()
@@ -65,7 +66,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
         }
         assert(endpoints.count == resolvedEndpoints.count)
 
-        let packetTunnelSettingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints)
+        packetTunnelSettingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints)
 
         let fileDescriptor = packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int32 //swiftlint:disable:this force_cast
         if fileDescriptor < 0 {
@@ -75,52 +76,23 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
             return
         }
 
-        let wireguardSettings = packetTunnelSettingsGenerator.uapiConfiguration()
-
-        var handle: Int32 = -1
-
-        func interfaceDescription(_ interface: NWInterface?) -> String {
-            if let interface = interface {
-                return "\(interface.name) (\(interface.type))"
-            } else {
-                return "None"
-            }
-        }
+        let wireguardSettings = packetTunnelSettingsGenerator!.uapiConfiguration()
 
         networkMonitor = NWPathMonitor()
-        var previousPrimaryNetworkPathInterface = networkMonitor?.currentPath.availableInterfaces.first
-        wg_log(.debug, message: "Network path primary interface: \(interfaceDescription(previousPrimaryNetworkPathInterface))")
-        networkMonitor?.pathUpdateHandler = { path in
-            guard handle >= 0 else { return }
-            if path.status == .satisfied {
-                wg_log(.debug, message: "Network change detected, re-establishing sockets and IPs: \(path.availableInterfaces)")
-                let primaryNetworkPathInterface = path.availableInterfaces.first
-                wg_log(.debug, message: "Network path primary interface: \(interfaceDescription(primaryNetworkPathInterface))")
-                let shouldIncludeListenPort = previousPrimaryNetworkPathInterface != primaryNetworkPathInterface
-                let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(shouldIncludeListenPort: shouldIncludeListenPort, currentListenPort: wgGetListenPort(handle))
-                let err = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) })
-                if err == -EADDRINUSE {
-                    // We expect this to happen only if shouldIncludeListenPort is true
-                    let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(shouldIncludeListenPort: shouldIncludeListenPort, currentListenPort: 0)
-                    _ = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) })
-                }
-                previousPrimaryNetworkPathInterface = primaryNetworkPathInterface
-            }
-        }
-        networkMonitor?.start(queue: DispatchQueue(label: "NetworkMonitor"))
-
-        handle = connect(interfaceName: tunnelConfiguration.interface.name, settings: wireguardSettings, fileDescriptor: fileDescriptor)
+        lastFirstInterface = networkMonitor!.currentPath.availableInterfaces.first
+        networkMonitor!.pathUpdateHandler = pathUpdate
+        networkMonitor!.start(queue: DispatchQueue(label: "NetworkMonitor"))
 
+        let handle = withStringsAsGoStrings(tunnelConfiguration.interface.name, wireguardSettings) { return wgTurnOn($0.0, $0.1, fileDescriptor) }
         if handle < 0 {
             wg_log(.error, staticMessage: "Starting tunnel failed: Could not start WireGuard")
             errorNotifier.notify(PacketTunnelProviderError.couldNotStartWireGuard)
             startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard)
             return
         }
-
         wgHandle = handle
 
-        let networkSettings: NEPacketTunnelNetworkSettings = packetTunnelSettingsGenerator.generateNetworkSettings()
+        let networkSettings: NEPacketTunnelNetworkSettings = packetTunnelSettingsGenerator!.generateNetworkSettings()
         setTunnelNetworkSettings(networkSettings) { error in
             if let error = error {
                 wg_log(.error, staticMessage: "Starting tunnel failed: Error setting network settings.")
@@ -165,8 +137,21 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
         }
     }
 
-    private func connect(interfaceName: String, settings: String, fileDescriptor: Int32) -> Int32 {
-        return withStringsAsGoStrings(interfaceName, settings) { return wgTurnOn($0.0, $0.1, fileDescriptor) }
+    private func pathUpdate(path: Network.NWPath) {
+        guard let handle = wgHandle, let packetTunnelSettingsGenerator = packetTunnelSettingsGenerator else { return }
+        var listenPort: UInt16?
+        if path.availableInterfaces.isEmpty || lastFirstInterface != path.availableInterfaces.first {
+            listenPort = wgGetListenPort(handle)
+            lastFirstInterface = path.availableInterfaces.first
+        }
+        guard path.status == .satisfied else { return }
+        wg_log(.debug, message: "Network change detected, re-establishing sockets and IPs: \(path.availableInterfaces)")
+        let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(currentListenPort: listenPort)
+        let err = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) })
+        if err == -EADDRINUSE && listenPort != nil {
+            let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(currentListenPort: 0)
+            _ = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) })
+        }
     }
 }
 
index 888769d67480fa00a62e931877637cd03ba42f3a..fd706d90a40aa6cfe25e9438fcef1cfadadb345c 100644 (file)
@@ -15,15 +15,11 @@ class PacketTunnelSettingsGenerator {
         self.resolvedEndpoints = resolvedEndpoints
     }
 
-    func endpointUapiConfiguration(shouldIncludeListenPort: Bool, currentListenPort: UInt16?) -> String {
+    func endpointUapiConfiguration(currentListenPort: UInt16?) -> String {
         var wgSettings = ""
 
-        if shouldIncludeListenPort {
-            if let tunnelListenPort = tunnelConfiguration.interface.listenPort {
-                wgSettings.append("listen_port=\(tunnelListenPort)\n")
-            } else if let currentListenPort = currentListenPort {
-                wgSettings.append("listen_port=\(currentListenPort)\n")
-            }
+        if let currentListenPort = currentListenPort {
+            wgSettings.append("listen_port=\(tunnelConfiguration.interface.listenPort ?? currentListenPort)\n")
         }
 
         for (index, peer) in tunnelConfiguration.peers.enumerated() {