]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
Rework DNS and routes in network extension
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 28 Dec 2018 18:34:31 +0000 (19:34 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 28 Dec 2018 18:38:03 +0000 (19:38 +0100)
The DNS resolver prior had useless comments, awful nesting, converted
bytes into strings and back into bytes, and generally made no sense.
That's been rewritten now.

But more fundumentally, this commit made the DNS resolver actually
accomplish its objective, by passing AI_ALL to it. It turns out, though,
that the Go library isn't actually using GAI in the way we need for
parsing IP addresses, so we actually need to do another round, this time
with hints flag as zero, so that we get the DNS64 address.

Additionally, since we're now binding sockets to interfaces, we can
entirely remove the excludedRoutes logic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
WireGuard/WireGuardNetworkExtension/DNSResolver.swift
WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
WireGuard/WireGuardNetworkExtension/PacketTunnelSettingsGenerator.swift

index 0a76c209b323c58659d956e446214180637692cf..1ab5623ae1abf1c2fe4957688a67c508893020f9 100644 (file)
@@ -45,7 +45,6 @@ class DNSResolver {
             let resolvedEndpoint = tuple.1
             if let endpoint = endpoint {
                 if resolvedEndpoint == nil {
-                    // DNS resolution failed
                     guard let hostname = endpoint.hostname() else { fatalError() }
                     hostnamesWithDnsResolutionFailure.append(hostname)
                 }
@@ -57,81 +56,97 @@ class DNSResolver {
         }
         return resolvedEndpoints
     }
-}
-
-extension DNSResolver {
-    // Based on DNS resolution code by Jason Donenfeld <jason@zx2c4.com>
-    // in parse_endpoint() in src/tools/config.c in the WireGuard codebase
 
-    //swiftlint:disable:next cyclomatic_complexity
     private static func resolveSync(endpoint: Endpoint) -> Endpoint? {
         switch endpoint.host {
         case .name(let name, _):
             var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0))
-
-            // The endpoint is a hostname and needs DNS resolution
-            if addressInfo(for: name, port: endpoint.port, resultPointer: &resultPointer) == 0 {
-                // getaddrinfo succeeded
-                let ipv4Buffer = UnsafeMutablePointer<Int8>.allocate(capacity: Int(INET_ADDRSTRLEN))
-                let ipv6Buffer = UnsafeMutablePointer<Int8>.allocate(capacity: Int(INET6_ADDRSTRLEN))
-                var ipv4AddressString: String?
-                var ipv6AddressString: String?
-                while resultPointer != nil {
-                    let result = resultPointer!.pointee
-                    resultPointer = result.ai_next
-                    if result.ai_family == AF_INET && result.ai_addrlen == MemoryLayout<sockaddr_in>.size {
-                        var sa4 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in.self).pointee
-                        if inet_ntop(result.ai_family, &sa4.sin_addr, ipv4Buffer, socklen_t(INET_ADDRSTRLEN)) != nil {
-                            ipv4AddressString = String(cString: ipv4Buffer)
-                            // If we found an IPv4 address, we can stop
-                            break
-                        }
-                    } else if result.ai_family == AF_INET6 && result.ai_addrlen == MemoryLayout<sockaddr_in6>.size {
-                        if ipv6AddressString != nil {
-                            // If we already have an IPv6 address, we can skip this one
-                            continue
-                        }
-                        var sa6 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in6.self).pointee
-                        if inet_ntop(result.ai_family, &sa6.sin6_addr, ipv6Buffer, socklen_t(INET6_ADDRSTRLEN)) != nil {
-                            ipv6AddressString = String(cString: ipv6Buffer)
-                        }
-                    }
-                }
-                ipv4Buffer.deallocate()
-                ipv6Buffer.deallocate()
-                // We prefer an IPv4 address over an IPv6 address
-                if let ipv4AddressString = ipv4AddressString, let ipv4Address = IPv4Address(ipv4AddressString) {
-                    return Endpoint(host: .ipv4(ipv4Address), port: endpoint.port)
-                } else if let ipv6AddressString = ipv6AddressString, let ipv6Address = IPv6Address(ipv6AddressString) {
-                    return Endpoint(host: .ipv6(ipv6Address), port: endpoint.port)
-                } else {
-                    return nil
+            var hints = addrinfo(
+                ai_flags: AI_ALL, // We set this to ALL so that we get v4 addresses even on DNS64 networks
+                ai_family: AF_UNSPEC,
+                ai_socktype: SOCK_DGRAM,
+                ai_protocol: IPPROTO_UDP,
+                ai_addrlen: 0,
+                ai_canonname: nil,
+                ai_addr: nil,
+                ai_next: nil)
+            if getaddrinfo(name, "\(endpoint.port)", &hints, &resultPointer) != 0 {
+                return nil
+            }
+            var next = resultPointer
+            var ipv4Address: IPv4Address?
+            var ipv6Address: IPv6Address?
+            while next != nil {
+                let result = next!.pointee
+                next = result.ai_next
+                if result.ai_family == AF_INET && result.ai_addrlen == MemoryLayout<sockaddr_in>.size {
+                    var sa4 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in.self).pointee
+                    ipv4Address = IPv4Address(Data(bytes: &sa4.sin_addr, count: MemoryLayout<in_addr>.size))
+                    break // If we found an IPv4 address, we can stop
+                } else if result.ai_family == AF_INET6 && result.ai_addrlen == MemoryLayout<sockaddr_in6>.size {
+                    var sa6 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in6.self).pointee
+                    ipv6Address = IPv6Address(Data(bytes: &sa6.sin6_addr, count: MemoryLayout<in6_addr>.size))
+                    continue // If we already have an IPv6 address, we can skip this one
                 }
+            }
+            freeaddrinfo(resultPointer)
+
+            // We prefer an IPv4 address over an IPv6 address
+            if let ipv4Address = ipv4Address {
+                return Endpoint(host: .ipv4(ipv4Address), port: endpoint.port)
+            } else if let ipv6Address = ipv6Address {
+                return Endpoint(host: .ipv6(ipv6Address), port: endpoint.port)
             } else {
-                // getaddrinfo failed
                 return nil
             }
         default:
-            // The endpoint is already resolved
             return endpoint
         }
     }
+}
+
+extension Endpoint {
+    func withReresolvedIP() -> Endpoint {
+        var ret = self
+        let hostname: String
+        switch host {
+        case .name(let name, _):
+            hostname = name
+        case .ipv4(let address):
+            hostname = "\(address)"
+        case .ipv6(let address):
+            hostname = "\(address)"
+        }
 
-    private static func addressInfo(for name: String, port: NWEndpoint.Port, resultPointer: inout UnsafeMutablePointer<addrinfo>?) -> Int32 {
+        var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0))
         var hints = addrinfo(
-            ai_flags: 0,
+            ai_flags: 0, // We set this to zero so that we actually resolve this using DNS64
             ai_family: AF_UNSPEC,
-            ai_socktype: SOCK_DGRAM, // WireGuard is UDP-only
-            ai_protocol: IPPROTO_UDP, // WireGuard is UDP-only
+            ai_socktype: SOCK_DGRAM,
+            ai_protocol: IPPROTO_UDP,
             ai_addrlen: 0,
             ai_canonname: nil,
             ai_addr: nil,
             ai_next: nil)
-
-        return getaddrinfo(
-            name.cString(using: .utf8), // Hostname
-            "\(port)".cString(using: .utf8), // Port
-            &hints,
-            &resultPointer)
+        if getaddrinfo(hostname, "\(port)", &hints, &resultPointer) != 0 || resultPointer == nil {
+            return ret
+        }
+        let result = resultPointer!.pointee
+        if result.ai_family == AF_INET && result.ai_addrlen == MemoryLayout<sockaddr_in>.size {
+            var sa4 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in.self).pointee
+            let addr = IPv4Address(Data(bytes: &sa4.sin_addr, count: MemoryLayout<in_addr>.size))
+            ret = Endpoint(host: .ipv4(addr!), port: port)
+        } else if result.ai_family == AF_INET6 && result.ai_addrlen == MemoryLayout<sockaddr_in6>.size {
+            var sa6 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in6.self).pointee
+            let addr = IPv6Address(Data(bytes: &sa6.sin6_addr, count: MemoryLayout<in6_addr>.size))
+            ret = Endpoint(host: .ipv6(addr!), port: port)
+        }
+        freeaddrinfo(resultPointer)
+        if ret.host != host {
+            wg_log(.debug, message: "DNS64: mapped \(host) to \(ret.host)")
+        } else {
+            wg_log(.debug, message: "DNS64: mapped \(host) to itself.")
+        }
+        return ret
     }
 }
index 67b1f4d03fc5d0c9d238d9986dd04844ba69593f..b00f197727c0ea66219ae5f2212b402b07f9f9e2 100644 (file)
@@ -17,6 +17,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
         networkMonitor?.cancel()
     }
 
+    //swiftlint:disable:next function_body_length
     override func startTunnel(options: [String: NSObject]?, completionHandler startTunnelCompletionHandler: @escaping (Error?) -> Void) {
         let activationAttemptId = options?["activationAttemptId"] as? String
         let errorNotifier = ErrorNotifier(activationAttemptId: activationAttemptId)
@@ -65,6 +66,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
                 if getsockopt(fileDescriptor, 2 /* SYSPROTO_CONTROL */, 2 /* UTUN_OPT_IFNAME */, ifnamePtr, &ifnameSize) == 0 {
                     self.ifname = String(cString: ifnamePtr)
                 }
+                ifnamePtr.deallocate()
                 wg_log(.info, message: "Tunnel interface is \(self.ifname ?? "unknown")")
                 let handle = self.packetTunnelSettingsGenerator!.uapiConfiguration().withGoString { return wgTurnOn($0, fileDescriptor) }
                 if handle < 0 {
index 5946843ddf7eb36c85e84ebee35f8c0b99188c52..4fd84fcc4891be6043c6288306d00e6e99c47efa 100644 (file)
@@ -18,7 +18,7 @@ class PacketTunnelSettingsGenerator {
         var wgSettings = ""
         for (index, peer) in tunnelConfiguration.peers.enumerated() {
             wgSettings.append("public_key=\(peer.publicKey.hexEncodedString())\n")
-            if let endpoint = resolvedEndpoints[index] {
+            if let endpoint = resolvedEndpoints[index]?.withReresolvedIP() {
                 if case .name(_, _) = endpoint.host { assert(false, "Endpoint is not resolved") }
                 wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
             }
@@ -42,7 +42,7 @@ class PacketTunnelSettingsGenerator {
             if let preSharedKey = peer.preSharedKey {
                 wgSettings.append("preshared_key=\(preSharedKey.hexEncodedString())\n")
             }
-            if let endpoint = resolvedEndpoints[index] {
+            if let endpoint = resolvedEndpoints[index]?.withReresolvedIP() {
                 if case .name(_, _) = endpoint.host { assert(false, "Endpoint is not resolved") }
                 wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
             }
@@ -63,18 +63,7 @@ class PacketTunnelSettingsGenerator {
          * make sense. So, we fill it in with this placeholder, which is not
          * a valid IP address that will actually route over the Internet.
          */
-        var remoteAddress = "0.0.0.0"
-        let endpointsCompact = resolvedEndpoints.compactMap { $0 }
-        if endpointsCompact.count == 1 {
-            switch endpointsCompact.first!.host {
-            case .ipv4(let address):
-                remoteAddress = "\(address)"
-            case .ipv6(let address):
-                remoteAddress = "\(address)"
-            default:
-                break
-            }
-        }
+        let remoteAddress = "0.0.0.0"
 
         let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: remoteAddress)
 
@@ -93,16 +82,13 @@ class PacketTunnelSettingsGenerator {
 
         let (ipv4Routes, ipv6Routes) = routes()
         let (ipv4IncludedRoutes, ipv6IncludedRoutes) = includedRoutes()
-        let (ipv4ExcludedRoutes, ipv6ExcludedRoutes) = excludedRoutes()
 
         let ipv4Settings = NEIPv4Settings(addresses: ipv4Routes.map { $0.destinationAddress }, subnetMasks: ipv4Routes.map { $0.destinationSubnetMask })
         ipv4Settings.includedRoutes = ipv4IncludedRoutes
-        ipv4Settings.excludedRoutes = ipv4ExcludedRoutes
         networkSettings.ipv4Settings = ipv4Settings
 
         let ipv6Settings = NEIPv6Settings(addresses: ipv6Routes.map { $0.destinationAddress }, networkPrefixLengths: ipv6Routes.map { $0.destinationNetworkPrefixLength })
         ipv6Settings.includedRoutes = ipv6IncludedRoutes
-        ipv6Settings.excludedRoutes = ipv6ExcludedRoutes
         networkSettings.ipv6Settings = ipv6Settings
 
         return networkSettings
@@ -152,24 +138,6 @@ class PacketTunnelSettingsGenerator {
         }
         return (ipv4IncludedRoutes, ipv6IncludedRoutes)
     }
-
-    private func excludedRoutes() -> ([NEIPv4Route], [NEIPv6Route]) {
-        var ipv4ExcludedRoutes = [NEIPv4Route]()
-        var ipv6ExcludedRoutes = [NEIPv6Route]()
-        for endpoint in resolvedEndpoints {
-            guard let endpoint = endpoint else { continue }
-            switch endpoint.host {
-            case .ipv4(let address):
-                ipv4ExcludedRoutes.append(NEIPv4Route(destinationAddress: "\(address)", subnetMask: "255.255.255.255"))
-            case .ipv6(let address):
-                ipv6ExcludedRoutes.append(NEIPv6Route(destinationAddress: "\(address)", networkPrefixLength: NSNumber(value: UInt8(128))))
-            default:
-                fatalError()
-            }
-        }
-        return (ipv4ExcludedRoutes, ipv6ExcludedRoutes)
-    }
-
 }
 
 private extension Data {