]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
WireGuardKit: Add WireGuardAdapter
authorAndrej Mihajlov <and@mullvad.net>
Wed, 4 Nov 2020 15:59:33 +0000 (16:59 +0100)
committerAndrej Mihajlov <and@mullvad.net>
Wed, 2 Dec 2020 10:08:09 +0000 (11:08 +0100)
Signed-off-by: Andrej Mihajlov <and@mullvad.net>
WireGuardKit/Sources/WireGuardKit/Array+ConcurrentMap.swift [new file with mode: 0644]
WireGuardKit/Sources/WireGuardKit/DNSResolver.swift
WireGuardKit/Sources/WireGuardKit/IPAddress+AddrInfo.swift [new file with mode: 0644]
WireGuardKit/Sources/WireGuardKit/PacketTunnelSettingsGenerator.swift
WireGuardKit/Sources/WireGuardKit/WireGuardAdapter.swift [new file with mode: 0644]
WireGuardKit/Sources/WireGuardKit/WireGuardKit.swift

diff --git a/WireGuardKit/Sources/WireGuardKit/Array+ConcurrentMap.swift b/WireGuardKit/Sources/WireGuardKit/Array+ConcurrentMap.swift
new file mode 100644 (file)
index 0000000..8a7992a
--- /dev/null
@@ -0,0 +1,34 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2019 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+
+extension Array {
+
+    /// Returns an array containing the results of mapping the given closure over the sequence’s
+    /// elements concurrently.
+    ///
+    /// - Parameters:
+    ///   - queue: The queue for performing concurrent computations.
+    ///            If the given queue is serial, the values are mapped in a serial fashion.
+    ///            Pass `nil` to perform computations on the current queue.
+    ///   - transform: the block to perform concurrent computations over the given element.
+    /// - Returns: an array of concurrently computed values.
+    func concurrentMap<U>(queue: DispatchQueue?, _ transform: (Element) -> U) -> [U] {
+        var result = [U?](repeating: nil, count: self.count)
+        let resultQueue = DispatchQueue(label: "ConcurrentMapQueue")
+
+        let execute = queue?.sync ?? { $0() }
+
+        execute {
+            DispatchQueue.concurrentPerform(iterations: self.count) { (index) in
+                let value = transform(self[index])
+                resultQueue.sync {
+                    result[index] = value
+                }
+            }
+        }
+
+        return result.map { $0! }
+    }
+}
index cdb9665aef0902d7efc5c925f14fb8b3454e6eea..df19cb24311ddf2ca60159a353009848bbc16613 100644 (file)
 import Network
 import Foundation
 
-class DNSResolver {
+enum DNSResolver {}
 
-    static func isAllEndpointsAlreadyResolved(endpoints: [Endpoint?]) -> Bool {
-        for endpoint in endpoints {
-            guard let endpoint = endpoint else { continue }
-            if !endpoint.hasHostAsIPAddress() {
-                return false
-            }
-        }
-        return true
-    }
+extension DNSResolver {
 
-    static func resolveSync(endpoints: [Endpoint?]) -> [Endpoint?]? {
-        let dispatchGroup = DispatchGroup()
+    /// Concurrent queue used for DNS resolutions
+    private static let resolverQueue = DispatchQueue(label: "DNSResolverQueue", qos: .default, attributes: .concurrent)
 
-        if isAllEndpointsAlreadyResolved(endpoints: endpoints) {
-            return endpoints
+    static func resolveSync(endpoints: [Endpoint?]) -> [Result<Endpoint, DNSResolutionError>?] {
+        let isAllEndpointsAlreadyResolved = endpoints.allSatisfy({ (maybeEndpoint) -> Bool in
+            return maybeEndpoint?.hasHostAsIPAddress() ?? true
+        })
+
+        if isAllEndpointsAlreadyResolved {
+            return endpoints.map { (endpoint) in
+                return endpoint.map { .success($0) }
+            }
         }
 
-        var resolvedEndpoints: [Endpoint?] = Array(repeating: nil, count: endpoints.count)
-        for (index, endpoint) in endpoints.enumerated() {
-            guard let endpoint = endpoint else { continue }
+        return endpoints.concurrentMap(queue: resolverQueue) {
+            (endpoint) -> Result<Endpoint, DNSResolutionError>? in
+            guard let endpoint = endpoint else { return nil }
+
             if endpoint.hasHostAsIPAddress() {
-                resolvedEndpoints[index] = endpoint
+                return .success(endpoint)
             } else {
-                let workItem = DispatchWorkItem {
-                    resolvedEndpoints[index] = DNSResolver.resolveSync(endpoint: endpoint)
-                }
-                DispatchQueue.global(qos: .userInitiated).async(group: dispatchGroup, execute: workItem)
+                return Result { try DNSResolver.resolveSync(endpoint: endpoint) }
+                    .mapError { $0 as! DNSResolutionError }
             }
         }
+    }
 
-        dispatchGroup.wait() // TODO: Timeout?
-
-        var hostnamesWithDnsResolutionFailure = [String]()
-        assert(endpoints.count == resolvedEndpoints.count)
-        for tuple in zip(endpoints, resolvedEndpoints) {
-            let endpoint = tuple.0
-            let resolvedEndpoint = tuple.1
-            if let endpoint = endpoint {
-                if resolvedEndpoint == nil {
-                    guard let hostname = endpoint.hostname() else { fatalError() }
-                    hostnamesWithDnsResolutionFailure.append(hostname)
-                }
-            }
+    private static func resolveSync(endpoint: Endpoint) throws -> Endpoint {
+        guard case .name(let name, _) = endpoint.host else {
+            return endpoint
         }
-        if !hostnamesWithDnsResolutionFailure.isEmpty {
-            wg_log(.error, message: "DNS resolution failed for the following hostnames: \(hostnamesWithDnsResolutionFailure.joined(separator: ", "))")
-            return nil
+
+        var hints = addrinfo()
+        hints.ai_flags = AI_ALL // We set this to ALL so that we get v4 addresses even on DNS64 networks
+        hints.ai_family = AF_UNSPEC
+        hints.ai_socktype = SOCK_DGRAM
+        hints.ai_protocol = IPPROTO_UDP
+
+        var resultPointer: UnsafeMutablePointer<addrinfo>?
+        defer {
+            resultPointer.flatMap { freeaddrinfo($0) }
         }
-        return resolvedEndpoints
-    }
 
-    private static func resolveSync(endpoint: Endpoint) -> Endpoint? {
-        switch endpoint.host {
-        case .name(let name, _):
-            var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0))
-            var hints = addrinfo(
-                ai_flags: AI_ALL, // We set this to ALL so that we get v4 addresses even on DNS64 networks
-                ai_family: AF_UNSPEC,
-                ai_socktype: SOCK_DGRAM,
-                ai_protocol: IPPROTO_UDP,
-                ai_addrlen: 0,
-                ai_canonname: nil,
-                ai_addr: nil,
-                ai_next: nil)
-            if getaddrinfo(name, "\(endpoint.port)", &hints, &resultPointer) != 0 {
-                return nil
-            }
-            var next = resultPointer
-            var ipv4Address: IPv4Address?
-            var ipv6Address: IPv6Address?
-            while next != nil {
-                let result = next!.pointee
-                next = result.ai_next
-                if result.ai_family == AF_INET && result.ai_addrlen == MemoryLayout<sockaddr_in>.size {
-                    var sa4 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in.self).pointee
-                    ipv4Address = IPv4Address(Data(bytes: &sa4.sin_addr, count: MemoryLayout<in_addr>.size))
-                    break // If we found an IPv4 address, we can stop
-                } else if result.ai_family == AF_INET6 && result.ai_addrlen == MemoryLayout<sockaddr_in6>.size {
-                    var sa6 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in6.self).pointee
-                    ipv6Address = IPv6Address(Data(bytes: &sa6.sin6_addr, count: MemoryLayout<in6_addr>.size))
-                    continue // If we already have an IPv6 address, we can skip this one
-                }
-            }
-            freeaddrinfo(resultPointer)
+        let errorCode = getaddrinfo(name, "\(endpoint.port)", &hints, &resultPointer)
+        if errorCode != 0 {
+            throw DNSResolutionError(errorCode: errorCode, address: name)
+        }
 
-            // We prefer an IPv4 address over an IPv6 address
-            if let ipv4Address = ipv4Address {
-                return Endpoint(host: .ipv4(ipv4Address), port: endpoint.port)
-            } else if let ipv6Address = ipv6Address {
-                return Endpoint(host: .ipv6(ipv6Address), port: endpoint.port)
-            } else {
-                return nil
+        var ipv4Address: IPv4Address?
+        var ipv6Address: IPv6Address?
+
+        var next: UnsafeMutablePointer<addrinfo>? = resultPointer
+        let iterator = AnyIterator { () -> addrinfo? in
+            let result = next?.pointee
+            next = result?.ai_next
+            return result
+        }
+
+        for addrInfo in iterator {
+            if let maybeIpv4Address = IPv4Address(addrInfo: addrInfo) {
+                ipv4Address = maybeIpv4Address
+                break // If we found an IPv4 address, we can stop
+            } else if let maybeIpv6Address = IPv6Address(addrInfo: addrInfo) {
+                ipv6Address = maybeIpv6Address
+                continue // If we already have an IPv6 address, we can skip this one
             }
-        default:
-            return endpoint
+        }
+
+        // We prefer an IPv4 address over an IPv6 address
+        if let ipv4Address = ipv4Address {
+            return Endpoint(host: .ipv4(ipv4Address), port: endpoint.port)
+        } else if let ipv6Address = ipv6Address {
+            return Endpoint(host: .ipv6(ipv6Address), port: endpoint.port)
+        } else {
+            // Must never happen
+            fatalError()
         }
     }
 }
 
 extension Endpoint {
-    func withReresolvedIP() -> Endpoint {
+    func withReresolvedIP() throws -> Endpoint {
         #if os(iOS)
-        var ret = self
         let hostname: String
         switch host {
         case .name(let name, _):
@@ -121,36 +103,30 @@ extension Endpoint {
             fatalError()
         }
 
-        var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0))
-        var hints = addrinfo(
-            ai_flags: 0, // We set this to zero so that we actually resolve this using DNS64
-            ai_family: AF_UNSPEC,
-            ai_socktype: SOCK_DGRAM,
-            ai_protocol: IPPROTO_UDP,
-            ai_addrlen: 0,
-            ai_canonname: nil,
-            ai_addr: nil,
-            ai_next: nil)
-        if getaddrinfo(hostname, "\(port)", &hints, &resultPointer) != 0 || resultPointer == nil {
-            return ret
+        var hints = addrinfo()
+        hints.ai_family = AF_UNSPEC
+        hints.ai_socktype = SOCK_DGRAM
+        hints.ai_protocol = IPPROTO_UDP
+        hints.ai_flags = AI_DEFAULT
+
+        var result: UnsafeMutablePointer<addrinfo>?
+        defer {
+            result.flatMap { freeaddrinfo($0) }
         }
-        let result = resultPointer!.pointee
-        if result.ai_family == AF_INET && result.ai_addrlen == MemoryLayout<sockaddr_in>.size {
-            var sa4 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in.self).pointee
-            let addr = IPv4Address(Data(bytes: &sa4.sin_addr, count: MemoryLayout<in_addr>.size))
-            ret = Endpoint(host: .ipv4(addr!), port: port)
-        } else if result.ai_family == AF_INET6 && result.ai_addrlen == MemoryLayout<sockaddr_in6>.size {
-            var sa6 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in6.self).pointee
-            let addr = IPv6Address(Data(bytes: &sa6.sin6_addr, count: MemoryLayout<in6_addr>.size))
-            ret = Endpoint(host: .ipv6(addr!), port: port)
+
+        let errorCode = getaddrinfo(hostname, "\(self.port)", &hints, &result)
+        if errorCode != 0 {
+            throw DNSResolutionError(errorCode: errorCode, address: hostname)
         }
-        freeaddrinfo(resultPointer)
-        if ret.host != host {
-            wg_log(.debug, message: "DNS64: mapped \(host) to \(ret.host)")
+
+        let addrInfo = result!.pointee
+        if let ipv4Address = IPv4Address(addrInfo: addrInfo) {
+            return Endpoint(host: .ipv4(ipv4Address), port: port)
+        } else if let ipv6Address = IPv6Address(addrInfo: addrInfo) {
+            return Endpoint(host: .ipv6(ipv6Address), port: port)
         } else {
-            wg_log(.debug, message: "DNS64: mapped \(host) to itself.")
+            fatalError()
         }
-        return ret
         #elseif os(macOS)
         return self
         #else
@@ -158,3 +134,18 @@ extension Endpoint {
         #endif
     }
 }
+
+/// An error type describing DNS resolution error
+public struct DNSResolutionError: LocalizedError {
+    public let errorCode: Int32
+    public let address: String
+
+    init(errorCode: Int32, address: String) {
+        self.errorCode = errorCode
+        self.address = address
+    }
+
+    public var errorDescription: String? {
+        return String(cString: gai_strerror(errorCode))
+    }
+}
diff --git a/WireGuardKit/Sources/WireGuardKit/IPAddress+AddrInfo.swift b/WireGuardKit/Sources/WireGuardKit/IPAddress+AddrInfo.swift
new file mode 100644 (file)
index 0000000..d860077
--- /dev/null
@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2019 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+import Network
+
+extension IPv4Address {
+    init?(addrInfo: addrinfo) {
+        guard addrInfo.ai_family == AF_INET else { return nil }
+
+        let addressData = addrInfo.ai_addr.withMemoryRebound(to: sockaddr_in.self, capacity: MemoryLayout<sockaddr_in>.size) { (ptr) -> Data in
+            return Data(bytes: &ptr.pointee.sin_addr, count: MemoryLayout<in_addr>.size)
+        }
+
+        if let ipAddress = IPv4Address(addressData) {
+            self = ipAddress
+        } else {
+            return nil
+        }
+    }
+}
+
+extension IPv6Address {
+    init?(addrInfo: addrinfo) {
+        guard addrInfo.ai_family == AF_INET6 else { return nil }
+
+        let addressData = addrInfo.ai_addr.withMemoryRebound(to: sockaddr_in6.self, capacity: MemoryLayout<sockaddr_in6>.size) { (ptr) -> Data in
+            return Data(bytes: &ptr.pointee.sin6_addr, count: MemoryLayout<in6_addr>.size)
+        }
+
+        if let ipAddress = IPv6Address(addressData) {
+            self = ipAddress
+        } else {
+            return nil
+        }
+    }
+}
index 5922b2c1b5c9620337808e03e7929c33dfbd2015..e6f0a1f159972da7ee828dbb352b5fbfa1aa6479 100644 (file)
@@ -19,7 +19,7 @@ class PacketTunnelSettingsGenerator {
         var wgSettings = ""
         for (index, peer) in tunnelConfiguration.peers.enumerated() {
             wgSettings.append("public_key=\(peer.publicKey.hexKey)\n")
-            if let endpoint = resolvedEndpoints[index]?.withReresolvedIP() {
+            if let endpoint = try? resolvedEndpoints[index]?.withReresolvedIP() {
                 if case .name(_, _) = endpoint.host { assert(false, "Endpoint is not resolved") }
                 wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
             }
@@ -42,7 +42,7 @@ class PacketTunnelSettingsGenerator {
             if let preSharedKey = peer.preSharedKey?.hexKey {
                 wgSettings.append("preshared_key=\(preSharedKey)\n")
             }
-            if let endpoint = resolvedEndpoints[index]?.withReresolvedIP() {
+            if let endpoint = try? resolvedEndpoints[index]?.withReresolvedIP() {
                 if case .name(_, _) = endpoint.host { assert(false, "Endpoint is not resolved") }
                 wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
             }
diff --git a/WireGuardKit/Sources/WireGuardKit/WireGuardAdapter.swift b/WireGuardKit/Sources/WireGuardKit/WireGuardAdapter.swift
new file mode 100644 (file)
index 0000000..ef7214d
--- /dev/null
@@ -0,0 +1,381 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2019 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+import NetworkExtension
+import libwg_go
+
+public enum WireGuardAdapterError: Error {
+    /// Failure to locate socket descriptor.
+    case cannotLocateSocketDescriptor
+
+    /// Failure to perform an operation in such state
+    case invalidState
+
+    /// Failure to resolve endpoints
+    case dnsResolution([DNSResolutionError])
+
+    /// Failure to set network settings
+    case setNetworkSettings(Error)
+
+    /// Timeout when calling to set network settings
+    case setNetworkSettingsTimeout
+
+    /// Failure to start WireGuard backend
+    case startWireGuardBackend(Int32)
+}
+
+public class WireGuardAdapter {
+    public typealias LogHandler = (WireGuardLogLevel, String) -> Void
+
+    /// Network routes monitor.
+    private var networkMonitor: NWPathMonitor?
+
+    /// Packet tunnel provider.
+    private weak var packetTunnelProvider: NEPacketTunnelProvider?
+
+    /// Log handler closure.
+    private var logHandler: LogHandler?
+
+    /// WireGuard internal handle returned by `wgTurnOn` that's used to associate the calls
+    /// with the specific WireGuard tunnel.
+    private var wireguardHandle: Int32?
+
+    /// Private queue used to synchronize access to `WireGuardAdapter` members.
+    private let workQueue = DispatchQueue(label: "WireGuardAdapterWorkQueue")
+
+    /// Flag that tells if the adapter has already started.
+    private var isStarted = false
+
+    /// Packet tunnel settings generator.
+    private var settingsGenerator: PacketTunnelSettingsGenerator?
+
+    /// Tunnel device file descriptor.
+    private var tunnelFileDescriptor: Int32? {
+        return self.packetTunnelProvider?.packetFlow.value(forKeyPath: "socket.fileDescriptor") as? Int32
+    }
+
+    /// Returns a Wireguard version.
+    class var version: String {
+        return String(cString: wgVersion())
+    }
+
+    /// Returns the tunnel device interface name, or nil on error.
+    /// - Returns: String.
+    public var interfaceName: String? {
+        guard let tunnelFileDescriptor = self.tunnelFileDescriptor else { return nil }
+
+        var buffer = [UInt8](repeating: 0, count: Int(IFNAMSIZ))
+
+        return buffer.withUnsafeMutableBufferPointer { (mutableBufferPointer) in
+            guard let baseAddress = mutableBufferPointer.baseAddress else { return nil }
+
+            var ifnameSize = socklen_t(IFNAMSIZ)
+            let result = getsockopt(
+                tunnelFileDescriptor,
+                2 /* SYSPROTO_CONTROL */,
+                2 /* UTUN_OPT_IFNAME */,
+                baseAddress,
+                &ifnameSize)
+
+            if result == 0 {
+                return String(cString: baseAddress)
+            } else {
+                return nil
+            }
+        }
+    }
+
+    // MARK: - Initialization
+
+    /// Designated initializer.
+    /// - Parameter packetTunnelProvider: an instance of `NEPacketTunnelProvider`. Internally stored
+    ///   as a weak reference.
+    public init(with packetTunnelProvider: NEPacketTunnelProvider) {
+        self.packetTunnelProvider = packetTunnelProvider
+    }
+
+    deinit {
+        // Force deactivate logger to make sure that no further calls to the instance of this class
+        // can happen after deallocation.
+        deactivateLogHandler()
+
+        // Cancel network monitor
+        networkMonitor?.cancel()
+
+        // Shutdown the tunnel
+        if let handle = self.wireguardHandle {
+            wgTurnOff(handle)
+        }
+    }
+
+    // MARK: - Public methods
+
+    /// Returns a runtime configuration from WireGuard.
+    /// - Parameter completionHandler: completion handler.
+    public func getRuntimeConfiguration(completionHandler: @escaping (String?) -> Void) {
+        workQueue.async {
+            guard let handle = self.wireguardHandle else {
+                completionHandler(nil)
+                return
+            }
+
+            if let settings = wgGetConfig(handle) {
+                completionHandler(String(cString: settings))
+                free(settings)
+            } else {
+                completionHandler(nil)
+            }
+        }
+    }
+
+    /// Set log handler.
+    /// - Parameter logHandler: log handler closure
+    public func setLogHandler(_ logHandler: LogHandler?) {
+        workQueue.async {
+            self.logHandler = logHandler
+        }
+
+        if logHandler == nil {
+            deactivateLogHandler()
+        } else {
+            activateLogHandler()
+        }
+    }
+
+    /// Start the tunnel tunnel.
+    /// - Parameters:
+    ///   - tunnelConfiguration: tunnel configuration.
+    ///   - completionHandler: completion handler.
+    public func start(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
+        workQueue.async {
+            guard !self.isStarted else {
+                completionHandler(.invalidState)
+                return
+            }
+
+            guard let tunnelFileDescriptor = self.tunnelFileDescriptor else {
+                completionHandler(.cannotLocateSocketDescriptor)
+                return
+            }
+
+            #if os(macOS)
+            wgEnableRoaming(true)
+            #endif
+
+            let networkMonitor = NWPathMonitor()
+            networkMonitor.pathUpdateHandler = { [weak self] path in
+                self?.didReceivePathUpdate(path: path)
+            }
+
+            networkMonitor.start(queue: self.workQueue)
+            self.networkMonitor = networkMonitor
+
+            self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { (settingsGenerator, error) in
+                if let error = error {
+                    completionHandler(error)
+                } else {
+                    var returnError: WireGuardAdapterError?
+                    let handle = wgTurnOn(settingsGenerator!.uapiConfiguration(), tunnelFileDescriptor)
+
+                    if handle >= 0 {
+                        self.wireguardHandle = handle
+                        self.isStarted = true
+                    } else {
+                        returnError = .startWireGuardBackend(handle)
+                    }
+
+                    completionHandler(returnError)
+                }
+            }
+        }
+    }
+
+    /// Stop the tunnel.
+    /// - Parameter completionHandler: completion handler.
+    public func stop(completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
+        workQueue.async {
+            guard self.isStarted else {
+                completionHandler(.invalidState)
+                return
+            }
+
+            self.networkMonitor?.cancel()
+            self.networkMonitor = nil
+
+            if let handle = self.wireguardHandle {
+                wgTurnOff(handle)
+                self.wireguardHandle = nil
+            }
+
+            self.isStarted = false
+
+            completionHandler(nil)
+        }
+    }
+
+    /// Update runtime configuration.
+    /// - Parameters:
+    ///   - tunnelConfiguration: tunnel configuration.
+    ///   - completionHandler: completion handler.
+    public func update(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
+        workQueue.async {
+            guard self.isStarted else {
+                completionHandler(.invalidState)
+                return
+            }
+
+            // Tell the system that the tunnel is going to reconnect using new WireGuard
+            // configuration.
+            // This will broadcast the `NEVPNStatusDidChange` notification to the GUI process.
+            self.packetTunnelProvider?.reasserting = true
+
+            self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { (settingsGenerator, error) in
+                if let error = error {
+                    completionHandler(error)
+                } else {
+                    if let handle = self.wireguardHandle {
+                        wgSetConfig(handle, settingsGenerator!.uapiConfiguration())
+                    }
+                    completionHandler(nil)
+                }
+
+                self.packetTunnelProvider?.reasserting = false
+            }
+        }
+    }
+
+    // MARK: - Private methods
+
+    /// Install WireGuard log handler.
+    private func activateLogHandler() {
+        let context = Unmanaged.passUnretained(self).toOpaque()
+        wgSetLogger(context) { (context, logLevel, message) in
+            guard let context = context, let message = message else { return }
+
+            let unretainedSelf = Unmanaged<WireGuardAdapter>.fromOpaque(context)
+                .takeUnretainedValue()
+
+            let swiftString = String(cString: message).trimmingCharacters(in: .newlines)
+            let tunnelLogLevel = WireGuardLogLevel(rawValue: logLevel) ?? .debug
+
+            unretainedSelf.handleLogLine(level: tunnelLogLevel, message: swiftString)
+        }
+    }
+
+    /// Uninstall WireGuard log handler.
+    private func deactivateLogHandler() {
+        wgSetLogger(nil, nil)
+    }
+
+    /// Resolve endpoints and update network configuration.
+    /// - Parameters:
+    ///   - tunnelConfiguration: tunnel configuration
+    ///   - completionHandler: completion handler
+    private func updateNetworkSettings(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (PacketTunnelSettingsGenerator?, WireGuardAdapterError?) -> Void) {
+        let resolvedEndpoints: [Endpoint?]
+
+        let resolvePeersResult = Result { try self.resolvePeers(for: tunnelConfiguration) }
+            .mapError { $0 as! WireGuardAdapterError }
+
+        switch resolvePeersResult {
+        case .success(let endpoints):
+            resolvedEndpoints = endpoints
+        case .failure(let error):
+            completionHandler(nil, error)
+            return
+        }
+
+        let settingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints)
+        let networkSettings = settingsGenerator.generateNetworkSettings()
+        self.settingsGenerator = settingsGenerator
+
+        var systemError: Error?
+        let condition = NSCondition()
+
+        // Activate the condition
+        condition.lock()
+        defer { condition.unlock() }
+
+        self.packetTunnelProvider?.setTunnelNetworkSettings(networkSettings, completionHandler: { (error) in
+            systemError = error
+            condition.signal()
+        })
+
+        // Packet tunnel's `setTunnelNetworkSettings` times out in certain
+        // scenarios & never calls the given callback.
+        let setTunnelNetworkSettingsTimeout: TimeInterval = 5 // seconds
+
+        if condition.wait(until: Date().addingTimeInterval(setTunnelNetworkSettingsTimeout)) {
+            let returnError = systemError.map { WireGuardAdapterError.setNetworkSettings($0) }
+
+            completionHandler(settingsGenerator, returnError)
+        } else {
+            completionHandler(nil, .setNetworkSettingsTimeout)
+        }
+    }
+
+    /// Resolve peers of the given tunnel configuration.
+    /// - Parameter tunnelConfiguration: tunnel configuration.
+    /// - Throws: an error of type `WireGuardAdapterError`.
+    /// - Returns: The list of resolved endpoints.
+    private func resolvePeers(for tunnelConfiguration: TunnelConfiguration) throws -> [Endpoint?] {
+        let endpoints = tunnelConfiguration.peers.map { $0.endpoint }
+        let resolutionResults = DNSResolver.resolveSync(endpoints: endpoints)
+        let resolutionErrors = resolutionResults.compactMap { (result) -> DNSResolutionError? in
+            if case .failure(let error) = result {
+                return error
+            } else {
+                return nil
+            }
+        }
+        assert(endpoints.count == resolutionResults.count)
+        guard resolutionErrors.isEmpty else {
+            throw WireGuardAdapterError.dnsResolution(resolutionErrors)
+        }
+
+        let resolvedEndpoints = resolutionResults.map { (result) -> Endpoint? in
+            return try? result?.get()
+        }
+
+        return resolvedEndpoints
+    }
+
+    /// Private helper to pass the logs coming from WireGuard to
+    /// - Parameters:
+    ///   - level: log level
+    ///   - message: message
+    private func handleLogLine(level: WireGuardLogLevel, message: String) {
+        workQueue.async {
+            self.logHandler?(level, message)
+        }
+    }
+
+    /// Helper method used by network path monitor.
+    /// - Parameter path: new network path
+    private func didReceivePathUpdate(path: Network.NWPath) {
+        guard self.isStarted else { return }
+
+        if let handle = self.wireguardHandle {
+            self.handleLogLine(level: .debug, message: "Network change detected with \(path.status) route and interface order \(path.availableInterfaces)")
+
+            #if os(iOS)
+            if let settingsGenerator = self.settingsGenerator {
+                wgSetConfig(handle, settingsGenerator.endpointUapiConfiguration())
+            }
+
+            // TODO: dynamically turn on or off WireGuard backend when entering airplane mode
+            #endif
+
+            wgBumpSockets(handle)
+        }
+    }
+}
+
+/// A enum describing Wireguard log levels defined in `api-ios.go` from `wireguard-apple`
+/// repository.
+public enum WireGuardLogLevel: Int32 {
+    case debug = 0
+    case info = 1
+    case error = 2
+}
index 8d005a996ddfd39a51d32576ea48e8c5177e854b..7e5b2c93346090c9e0722a905b11a69aaaebee72 100644 (file)
@@ -2,8 +2,7 @@
 // Copyright © 2018-2019 WireGuard LLC. All Rights Reserved.
 
 import Foundation
-import libwg_go
 
 public func getWireGuardVersion() -> String {
-    return String(cString: wgVersion()!)
+    return WireGuardAdapter.version
 }