]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
Do not require NetworkExtension to know its own name
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 21 Dec 2018 21:05:47 +0000 (22:05 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 21 Dec 2018 21:05:47 +0000 (22:05 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
WireGuard/WireGuardNetworkExtension/ErrorNotifier.swift
WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
wireguard-go-bridge/src/api-ios.go
wireguard-go-bridge/wireguard.h

index a9bcc144727d747ea6764911c9ce66e611136191..ac1a6368072475c7c8268ba1e1aa4f9f2df66ea1 100644 (file)
@@ -8,8 +8,6 @@ class ErrorNotifier {
     let activationAttemptId: String?
     weak var tunnelProvider: NEPacketTunnelProvider?
 
-    var tunnelName: String?
-
     init(activationAttemptId: String?, tunnelProvider: NEPacketTunnelProvider) {
         self.activationAttemptId = activationAttemptId
         self.tunnelProvider = tunnelProvider
index 27a42c5d7b1e16b86340c720e06e3dd4de1bfc1a..5e994c0c82b4adb50b432ea5750a7e74e95acb8a 100644 (file)
@@ -37,10 +37,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
 
         configureLogger()
 
-        let tunnelName = tunnelConfiguration.interface.name
-        wg_log(.info, message: "Starting tunnel '\(tunnelName)' from the " + (activationAttemptId == nil ? "OS directly, rather than the app" : "app"))
-
-        errorNotifier.tunnelName = tunnelName
+        wg_log(.info, message: "Starting tunnel from the " + (activationAttemptId == nil ? "OS directly, rather than the app" : "app"))
 
         let endpoints = tunnelConfiguration.peers.map { $0.endpoint }
         guard let resolvedEndpoints = DNSResolver.resolveSync(endpoints: endpoints) else {
@@ -67,7 +64,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
         networkMonitor!.pathUpdateHandler = pathUpdate
         networkMonitor!.start(queue: DispatchQueue(label: "NetworkMonitor"))
 
-        let handle = withStringsAsGoStrings(tunnelConfiguration.interface.name, wireguardSettings) { return wgTurnOn($0.0, $0.1, fileDescriptor) }
+        let handle = wireguardSettings.withGoString { return wgTurnOn($0, fileDescriptor) }
         if handle < 0 {
             wg_log(.error, staticMessage: "Starting tunnel failed: Could not start WireGuard")
             errorNotifier.notify(PacketTunnelProviderError.couldNotStartWireGuard)
@@ -131,19 +128,20 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
         guard path.status == .satisfied else { return }
         wg_log(.debug, message: "Network change detected, re-establishing sockets and IPs: \(path.availableInterfaces)")
         let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(currentListenPort: listenPort)
-        let err = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) })
+        let err = endpointString.withGoString { return wgSetConfig(handle, $0) }
         if err == -EADDRINUSE && listenPort != nil {
             let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(currentListenPort: 0)
-            _ = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) })
+            _ = endpointString.withGoString { return wgSetConfig(handle, $0) }
+
         }
     }
 }
 
-// swiftlint:disable:next large_tuple identifier_name
-func withStringsAsGoStrings<R>(_ s1: String, _ s2: String? = nil, _ s3: String? = nil, _ s4: String? = nil, call: ((gostring_t, gostring_t, gostring_t, gostring_t)) -> R) -> R {
-    // swiftlint:disable:next large_tuple identifier_name
-    func helper(_ p1: UnsafePointer<Int8>?, _ p2: UnsafePointer<Int8>?, _ p3: UnsafePointer<Int8>?, _ p4: UnsafePointer<Int8>?, _ call: ((gostring_t, gostring_t, gostring_t, gostring_t)) -> R) -> R {
-        return call((gostring_t(p: p1, n: s1.utf8.count), gostring_t(p: p2, n: s2?.utf8.count ?? 0), gostring_t(p: p3, n: s3?.utf8.count ?? 0), gostring_t(p: p4, n: s4?.utf8.count ?? 0)))
+extension String {
+    func withGoString<R>(_ call: (gostring_t) -> R) -> R {
+        func helper(_ pointer: UnsafePointer<Int8>?, _ call: (gostring_t) -> R) -> R {
+            return call(gostring_t(p: pointer, n: utf8.count))
+        }
+        return helper(self, call)
     }
-    return helper(s1, s2, s3, s4, call)
 }
index 3d35d1e34aaaa64b258db2709ddeb126bb7c2930..902cfac9586fd14652e3e4735a952aa2b4530f1e 100644 (file)
@@ -32,15 +32,14 @@ var loggerFunc unsafe.Pointer
 var versionString *C.char
 
 type CLogger struct {
-       level         C.int
-       interfaceName string
+       level C.int
 }
 
 func (l *CLogger) Write(p []byte) (int, error) {
        if uintptr(loggerFunc) == 0 {
                return 0, errors.New("No logger initialized")
        }
-       message := C.CString(l.interfaceName + ": " + string(p))
+       message := C.CString(string(p))
        C.callLogger(loggerFunc, l.level, message)
        C.free(unsafe.Pointer(message))
        return len(p), nil
@@ -75,17 +74,13 @@ func wgSetLogger(loggerFn uintptr) {
 }
 
 //export wgTurnOn
-func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
-       interfaceName := string([]byte(ifnameRef))
-
+func wgTurnOn(settings string, tunFd int32) int32 {
        logger := &Logger{
-               Debug: log.New(&CLogger{level: 0, interfaceName: interfaceName}, "", 0),
-               Info:  log.New(&CLogger{level: 1, interfaceName: interfaceName}, "", 0),
-               Error: log.New(&CLogger{level: 2, interfaceName: interfaceName}, "", 0),
+               Debug: log.New(&CLogger{level: 0}, "", 0),
+               Info:  log.New(&CLogger{level: 1}, "", 0),
+               Error: log.New(&CLogger{level: 2}, "", 0),
        }
 
-       logger.Debug.Println("Debug log enabled")
-
        tun, _, err := tun.CreateTUNFromFD(int(tunFd))
        if err != nil {
                logger.Error.Println(err)
index fec352d57c9a07d3b6fc4f25c4f5e74c49128bf8..d7183c9793187f5ac5b9aa452a3a18dabb971838 100644 (file)
@@ -12,7 +12,7 @@
 typedef struct { const char *p; size_t n; } gostring_t;
 typedef void(*logger_fn_t)(int level, const char *msg);
 extern void wgSetLogger(logger_fn_t logger_fn);
-extern int wgTurnOn(gostring_t ifname, gostring_t settings, int32_t tun_fd);
+extern int wgTurnOn(gostring_t settings, int32_t tun_fd);
 extern void wgTurnOff(int handle);
 extern int64_t wgSetConfig(int handle, gostring_t settings);
 extern uint16_t wgGetListenPort(int handle);