]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
WireGuardKit: Conditionally turn on/off wireguard-go
authorAndrej Mihajlov <and@mullvad.net>
Tue, 1 Dec 2020 10:18:31 +0000 (11:18 +0100)
committerAndrej Mihajlov <and@mullvad.net>
Fri, 11 Dec 2020 10:15:22 +0000 (11:15 +0100)
Signed-off-by: Andrej Mihajlov <and@mullvad.net>
Sources/WireGuardKit/DNSResolver.swift
Sources/WireGuardKit/PacketTunnelSettingsGenerator.swift
Sources/WireGuardKit/WireGuardAdapter.swift

index 5315c94ee22c0d1f629d7db460a672bfadd99ef3..7a0f2e916699a3aa50170da3c964c8eba86fc9f4 100644 (file)
@@ -109,7 +109,7 @@ extension Endpoint {
         hints.ai_family = AF_UNSPEC
         hints.ai_socktype = SOCK_DGRAM
         hints.ai_protocol = IPPROTO_UDP
-        hints.ai_flags = AI_DEFAULT
+        hints.ai_flags = 0 // We set this to zero so that we actually resolve this using DNS64
 
         var result: UnsafeMutablePointer<addrinfo>?
         defer {
index 9efe1fa054300423ceaad45285b64554dd6ae456..0ddc1b7da3864f0f4d7f0faf803285f72c12e709 100644 (file)
@@ -9,6 +9,9 @@ import NetworkExtension
 import WireGuardKitC
 #endif
 
+/// A type alias for `Result` type that holds a tuple with source and resolved endpoint.
+typealias EndpointResolutionResult = Result<(Endpoint, Endpoint), DNSResolutionError>
+
 class PacketTunnelSettingsGenerator {
     let tunnelConfiguration: TunnelConfiguration
     let resolvedEndpoints: [Endpoint?]
@@ -18,31 +21,27 @@ class PacketTunnelSettingsGenerator {
         self.resolvedEndpoints = resolvedEndpoints
     }
 
-    func endpointUapiConfiguration() -> (String, [DNSResolutionError]) {
-        var resolutionErrors = [DNSResolutionError]()
+    func endpointUapiConfiguration() -> (String, [EndpointResolutionResult?]) {
+        var resolutionResults = [EndpointResolutionResult?]()
         var wgSettings = ""
-        for (index, peer) in tunnelConfiguration.peers.enumerated() {
+
+        assert(tunnelConfiguration.peers.count == resolvedEndpoints.count)
+        for (peer, resolvedEndpoint) in zip(self.tunnelConfiguration.peers, self.resolvedEndpoints) {
             wgSettings.append("public_key=\(peer.publicKey.hexKey)\n")
-            let result = Result { try resolvedEndpoints[index]?.withReresolvedIP() }
-                .mapError { error -> DNSResolutionError in
-                    // swiftlint:disable:next force_cast
-                    return error as! DNSResolutionError
-                }
 
-            switch result {
-            case .success(.some(let endpoint)):
-                if case .name = endpoint.host { assert(false, "Endpoint is not resolved") }
-                wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
-            case .success(.none):
-                break
-            case .failure(let error):
-                resolutionErrors.append(error)
+            let result = resolvedEndpoint.map(Self.reresolveEndpoint)
+            if case .success((_, let resolvedEndpoint)) = result {
+                if case .name = resolvedEndpoint.host { assert(false, "Endpoint is not resolved") }
+                wgSettings.append("endpoint=\(resolvedEndpoint.stringRepresentation)\n")
             }
+            resolutionResults.append(result)
         }
-        return (wgSettings, resolutionErrors)
+
+        return (wgSettings, resolutionResults)
     }
 
-    func uapiConfiguration() -> String {
+    func uapiConfiguration() -> (String, [EndpointResolutionResult?]) {
+        var resolutionResults = [EndpointResolutionResult?]()
         var wgSettings = ""
         wgSettings.append("private_key=\(tunnelConfiguration.interface.privateKey.hexKey)\n")
         if let listenPort = tunnelConfiguration.interface.listenPort {
@@ -52,15 +51,19 @@ class PacketTunnelSettingsGenerator {
             wgSettings.append("replace_peers=true\n")
         }
         assert(tunnelConfiguration.peers.count == resolvedEndpoints.count)
-        for (index, peer) in tunnelConfiguration.peers.enumerated() {
+        for (peer, resolvedEndpoint) in zip(self.tunnelConfiguration.peers, self.resolvedEndpoints) {
             wgSettings.append("public_key=\(peer.publicKey.hexKey)\n")
             if let preSharedKey = peer.preSharedKey?.hexKey {
                 wgSettings.append("preshared_key=\(preSharedKey)\n")
             }
-            if let endpoint = try? resolvedEndpoints[index]?.withReresolvedIP() {
-                if case .name = endpoint.host { assert(false, "Endpoint is not resolved") }
-                wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
+
+            let result = resolvedEndpoint.map(Self.reresolveEndpoint)
+            if case .success((_, let resolvedEndpoint)) = result {
+                if case .name = resolvedEndpoint.host { assert(false, "Endpoint is not resolved") }
+                wgSettings.append("endpoint=\(resolvedEndpoint.stringRepresentation)\n")
             }
+            resolutionResults.append(result)
+
             let persistentKeepAlive = peer.persistentKeepAlive ?? 0
             wgSettings.append("persistent_keepalive_interval=\(persistentKeepAlive)\n")
             if !peer.allowedIPs.isEmpty {
@@ -68,7 +71,7 @@ class PacketTunnelSettingsGenerator {
                 peer.allowedIPs.forEach { wgSettings.append("allowed_ip=\($0.stringRepresentation)\n") }
             }
         }
-        return wgSettings
+        return (wgSettings, resolutionResults)
     }
 
     func generateNetworkSettings() -> NEPacketTunnelNetworkSettings {
@@ -163,4 +166,12 @@ class PacketTunnelSettingsGenerator {
         }
         return (ipv4IncludedRoutes, ipv6IncludedRoutes)
     }
+
+    private class func reresolveEndpoint(endpoint: Endpoint) -> EndpointResolutionResult {
+        return Result { (endpoint, try endpoint.withReresolvedIP()) }
+            .mapError { error -> DNSResolutionError in
+                // swiftlint:disable:next force_cast
+                return error as! DNSResolutionError
+            }
+    }
 }
index 113c06fa078e5ff2bb7f9aebf5bbb58fda87dba4..bf885c2ab91b81a8eb0ebe30c3afba64c1101681 100644 (file)
@@ -28,6 +28,18 @@ public enum WireGuardAdapterError: Error {
     case startWireGuardBackend(Int32)
 }
 
+/// Enum representing internal state of the `WireGuardAdapter`
+private enum State {
+    /// The tunnel is stopped
+    case stopped
+
+    /// The tunnel is up and running
+    case started(_ handle: Int32, _ settingsGenerator: PacketTunnelSettingsGenerator)
+
+    /// The tunnel is temporarily shutdown due to device going offline
+    case temporaryShutdown(_ settingsGenerator: PacketTunnelSettingsGenerator)
+}
+
 public class WireGuardAdapter {
     public typealias LogHandler = (WireGuardLogLevel, String) -> Void
 
@@ -40,15 +52,11 @@ public class WireGuardAdapter {
     /// Log handler closure.
     private let logHandler: LogHandler
 
-    /// WireGuard internal handle returned by `wgTurnOn` that's used to associate the calls
-    /// with the specific WireGuard tunnel.
-    private var wireguardHandle: Int32?
-
     /// Private queue used to synchronize access to `WireGuardAdapter` members.
     private let workQueue = DispatchQueue(label: "WireGuardAdapterWorkQueue")
 
-    /// Packet tunnel settings generator.
-    private var settingsGenerator: PacketTunnelSettingsGenerator?
+    /// Adapter state.
+    private var state: State = .stopped
 
     /// Tunnel device file descriptor.
     private var tunnelFileDescriptor: Int32? {
@@ -108,7 +116,7 @@ public class WireGuardAdapter {
         networkMonitor?.cancel()
 
         // Shutdown the tunnel
-        if let handle = self.wireguardHandle {
+        if case .started(let handle, _) = self.state {
             wgTurnOff(handle)
         }
     }
@@ -119,7 +127,7 @@ public class WireGuardAdapter {
     /// - Parameter completionHandler: completion handler.
     public func getRuntimeConfiguration(completionHandler: @escaping (String?) -> Void) {
         workQueue.async {
-            guard let handle = self.wireguardHandle else {
+            guard case .started(let handle, _) = self.state else {
                 completionHandler(nil)
                 return
             }
@@ -139,16 +147,11 @@ public class WireGuardAdapter {
     ///   - completionHandler: completion handler.
     public func start(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
         workQueue.async {
-            guard self.wireguardHandle == nil else {
+            guard case .stopped = self.state else {
                 completionHandler(.invalidState)
                 return
             }
 
-            guard let tunnelFileDescriptor = self.tunnelFileDescriptor else {
-                completionHandler(.cannotLocateTunnelFileDescriptor)
-                return
-            }
-
             #if os(macOS)
             wgEnableRoaming(true)
             #endif
@@ -157,25 +160,26 @@ public class WireGuardAdapter {
             networkMonitor.pathUpdateHandler = { [weak self] path in
                 self?.didReceivePathUpdate(path: path)
             }
-
             networkMonitor.start(queue: self.workQueue)
-            self.networkMonitor = networkMonitor
 
-            self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { settingsGenerator, error in
-                if let error = error {
-                    completionHandler(error)
-                } else {
-                    var returnError: WireGuardAdapterError?
-                    let handle = wgTurnOn(settingsGenerator!.uapiConfiguration(), tunnelFileDescriptor)
+            do {
+                let settingsGenerator = try self.makeSettingsGenerator(with: tunnelConfiguration)
+                try self.setNetworkSettings(settingsGenerator.generateNetworkSettings())
 
-                    if handle >= 0 {
-                        self.wireguardHandle = handle
-                    } else {
-                        returnError = .startWireGuardBackend(handle)
-                    }
+                let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration()
+                self.logEndpointResolutionResults(resolutionResults)
 
-                    completionHandler(returnError)
-                }
+                self.state = .started(
+                    try self.startWireGuardBackend(wgConfig: wgConfig),
+                    settingsGenerator
+                )
+                self.networkMonitor = networkMonitor
+                completionHandler(nil)
+            } catch let error as WireGuardAdapterError {
+                networkMonitor.cancel()
+                completionHandler(error)
+            } catch {
+                fatalError()
             }
         }
     }
@@ -184,7 +188,14 @@ public class WireGuardAdapter {
     /// - Parameter completionHandler: completion handler.
     public func stop(completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
         workQueue.async {
-            guard let handle = self.wireguardHandle else {
+            switch self.state {
+            case .started(let handle, _):
+                wgTurnOff(handle)
+
+            case .temporaryShutdown:
+                break
+
+            case .stopped:
                 completionHandler(.invalidState)
                 return
             }
@@ -192,8 +203,7 @@ public class WireGuardAdapter {
             self.networkMonitor?.cancel()
             self.networkMonitor = nil
 
-            wgTurnOff(handle)
-            self.wireguardHandle = nil
+            self.state = .stopped
 
             completionHandler(nil)
         }
@@ -205,7 +215,7 @@ public class WireGuardAdapter {
     ///   - completionHandler: completion handler.
     public func update(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
         workQueue.async {
-            guard let handle = self.wireguardHandle else {
+            if case .stopped = self.state {
                 completionHandler(.invalidState)
                 return
             }
@@ -214,16 +224,35 @@ public class WireGuardAdapter {
             // configuration.
             // This will broadcast the `NEVPNStatusDidChange` notification to the GUI process.
             self.packetTunnelProvider?.reasserting = true
+            defer {
+                self.packetTunnelProvider?.reasserting = false
+            }
 
-            self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { settingsGenerator, error in
-                if let error = error {
-                    completionHandler(error)
-                } else {
-                    wgSetConfig(handle, settingsGenerator!.uapiConfiguration())
-                    completionHandler(nil)
+            do {
+                let settingsGenerator = try self.makeSettingsGenerator(with: tunnelConfiguration)
+                try self.setNetworkSettings(settingsGenerator.generateNetworkSettings())
+
+                switch self.state {
+                case .started(let handle, _):
+                    let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration()
+                    self.logEndpointResolutionResults(resolutionResults)
+
+                    wgSetConfig(handle, wgConfig)
+
+                    self.state = .started(handle, settingsGenerator)
+
+                case .temporaryShutdown:
+                    self.state = .temporaryShutdown(settingsGenerator)
+
+                case .stopped:
+                    fatalError()
                 }
 
-                self.packetTunnelProvider?.reasserting = false
+                completionHandler(nil)
+            } catch let error as WireGuardAdapterError {
+                completionHandler(error)
+            } catch {
+                fatalError()
             }
         }
     }
@@ -246,30 +275,15 @@ public class WireGuardAdapter {
         }
     }
 
-    /// Resolve endpoints and update network configuration.
+    /// Set network tunnel configuration.
+    /// This method ensures that the call to `setTunnelNetworkSettings` does not time out, as in
+    /// certain scenarios the completion handler given to it may not be invoked by the system.
+    ///
     /// - Parameters:
-    ///   - tunnelConfiguration: tunnel configuration
-    ///   - completionHandler: completion handler
-    private func updateNetworkSettings(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (PacketTunnelSettingsGenerator?, WireGuardAdapterError?) -> Void) {
-        let resolvedEndpoints: [Endpoint?]
-
-        let resolvePeersResult = Result { try self.resolvePeers(for: tunnelConfiguration) }
-            .mapError { error -> WireGuardAdapterError in
-                // swiftlint:disable:next force_cast
-                return error as! WireGuardAdapterError
-            }
-
-        switch resolvePeersResult {
-        case .success(let endpoints):
-            resolvedEndpoints = endpoints
-        case .failure(let error):
-            completionHandler(nil, error)
-            return
-        }
-
-        let settingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints)
-        let networkSettings = settingsGenerator.generateNetworkSettings()
-
+    ///   - networkSettings: an instance of type `NEPacketTunnelNetworkSettings`.
+    /// - Throws: an error of type `WireGuardAdapterError`.
+    /// - Returns: `PacketTunnelSettingsGenerator`.
+    private func setNetworkSettings(_ networkSettings: NEPacketTunnelNetworkSettings) throws {
         var systemError: Error?
         let condition = NSCondition()
 
@@ -287,16 +301,11 @@ public class WireGuardAdapter {
         let setTunnelNetworkSettingsTimeout: TimeInterval = 5 // seconds
 
         if condition.wait(until: Date().addingTimeInterval(setTunnelNetworkSettingsTimeout)) {
-            let returnError = systemError.map { WireGuardAdapterError.setNetworkSettings($0) }
-
-            // Only assign `settingsGenerator` when `setTunnelNetworkSettings` succeeded.
-            if returnError == nil {
-                self.settingsGenerator = settingsGenerator
+            if let systemError = systemError {
+                throw WireGuardAdapterError.setNetworkSettings(systemError)
             }
-
-            completionHandler(settingsGenerator, returnError)
         } else {
-            completionHandler(nil, .setNetworkSettingsTimeout)
+            throw WireGuardAdapterError.setNetworkSettingsTimeout
         }
     }
 
@@ -327,24 +336,97 @@ public class WireGuardAdapter {
         return resolvedEndpoints
     }
 
+    /// Start WireGuard backend.
+    /// - Parameter wgConfig: WireGuard configuration
+    /// - Throws: an error of type `WireGuardAdapterError`
+    /// - Returns: tunnel handle
+    private func startWireGuardBackend(wgConfig: String) throws -> Int32 {
+        guard let tunnelFileDescriptor = self.tunnelFileDescriptor else {
+            throw WireGuardAdapterError.cannotLocateTunnelFileDescriptor
+        }
+
+        let handle = wgTurnOn(wgConfig, tunnelFileDescriptor)
+        if handle >= 0 {
+            return handle
+        } else {
+            throw WireGuardAdapterError.startWireGuardBackend(handle)
+        }
+    }
+
+    /// Resolves the hostnames in the given tunnel configuration and return settings generator.
+    /// - Parameter tunnelConfiguration: an instance of type `TunnelConfiguration`.
+    /// - Throws: an error of type `WireGuardAdapterError`.
+    /// - Returns: an instance of type `PacketTunnelSettingsGenerator`.
+    private func makeSettingsGenerator(with tunnelConfiguration: TunnelConfiguration) throws -> PacketTunnelSettingsGenerator {
+        return PacketTunnelSettingsGenerator(
+            tunnelConfiguration: tunnelConfiguration,
+            resolvedEndpoints: try self.resolvePeers(for: tunnelConfiguration)
+        )
+    }
+
+    /// Log DNS resolution results.
+    /// - Parameter resolutionErrors: an array of type `[DNSResolutionError]`.
+    private func logEndpointResolutionResults(_ resolutionResults: [EndpointResolutionResult?]) {
+        for case .some(let result) in resolutionResults {
+            switch result {
+            case .success((let sourceEndpoint, let resolvedEndpoint)):
+                if sourceEndpoint.host == resolvedEndpoint.host {
+                    self.logHandler(.debug, "DNS64: mapped \(sourceEndpoint.host) to itself.")
+                } else {
+                    self.logHandler(.debug, "DNS64: mapped \(sourceEndpoint.host) to \(resolvedEndpoint.host)")
+                }
+            case .failure(let resolutionError):
+                self.logHandler(.error, "Failed to resolve endpoint \(resolutionError.address): \(resolutionError.errorDescription ?? "(nil)")")
+            }
+        }
+    }
+
     /// Helper method used by network path monitor.
     /// - Parameter path: new network path
     private func didReceivePathUpdate(path: Network.NWPath) {
-        guard let handle = self.wireguardHandle else { return }
-
         self.logHandler(.debug, "Network change detected with \(path.status) route and interface order \(path.availableInterfaces)")
 
-        #if os(iOS)
-        if let settingsGenerator = self.settingsGenerator {
-            let (wgSettings, resolutionErrors) = settingsGenerator.endpointUapiConfiguration()
-            for error in resolutionErrors {
-                self.logHandler(.error, "Failed to re-resolve \(error.address): \(error.errorDescription ?? "(nil)")")
+        switch self.state {
+        case .started(let handle, let settingsGenerator):
+            if path.status.isSatisfiable {
+                #if os(iOS)
+                let (wgConfig, resolutionResults) = settingsGenerator.endpointUapiConfiguration()
+                self.logEndpointResolutionResults(resolutionResults)
+
+                wgSetConfig(handle, wgConfig)
+                #endif
+
+                wgBumpSockets(handle)
+            } else {
+                self.logHandler(.info, "Connectivity offline, pausing backend.")
+
+                self.state = .temporaryShutdown(settingsGenerator)
+                wgTurnOff(handle)
             }
-            wgSetConfig(handle, wgSettings)
-        }
-        #endif
 
-        wgBumpSockets(handle)
+        case .temporaryShutdown(let settingsGenerator):
+            guard path.status.isSatisfiable else { return }
+
+            self.logHandler(.info, "Connectivity online, resuming backend.")
+
+            do {
+                try self.setNetworkSettings(settingsGenerator.generateNetworkSettings())
+
+                let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration()
+                self.logEndpointResolutionResults(resolutionResults)
+
+                self.state = .started(
+                    try self.startWireGuardBackend(wgConfig: wgConfig),
+                    settingsGenerator
+                )
+            } catch {
+                self.logHandler(.error, "Failed to restart backend: \(error.localizedDescription)")
+            }
+
+        case .stopped:
+            // no-op
+            break
+        }
     }
 }
 
@@ -354,3 +436,17 @@ public enum WireGuardLogLevel: Int32 {
     case info = 1
     case error = 2
 }
+
+private extension Network.NWPath.Status {
+    /// Returns `true` if the path is potentially satisfiable.
+    var isSatisfiable: Bool {
+        switch self {
+        case .requiresConnection, .satisfied:
+            return true
+        case .unsatisfied:
+            return false
+        @unknown default:
+            return true
+        }
+    }
+}