]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
VPN: TunnelsManager should keep track of tunnel names to prevent duplicate names
authorRoopesh Chander <roop@roopc.net>
Wed, 31 Oct 2018 11:33:32 +0000 (17:03 +0530)
committerRoopesh Chander <roop@roopc.net>
Wed, 31 Oct 2018 20:17:53 +0000 (01:47 +0530)
Signed-off-by: Roopesh Chander <roop@roopc.net>
WireGuard/WireGuard/VPN/TunnelsManager.swift

index 2e28689f5ff0af8432c5ae8623f9aeccac03b7f7..a9cc1acd4a580ff13a4498fc6c6e35017e547e09 100644 (file)
@@ -30,21 +30,25 @@ class TunnelsManager {
     private var isModifyingTunnel: Bool = false
     private var isDeletingTunnel: Bool = false
 
+    private var tunnelNames: Set<String>
     private var currentTunnel: TunnelContainer?
     private var currentTunnelStatusObservationToken: AnyObject?
 
     init(tunnelProviders: [NETunnelProviderManager]) {
+        var tunnelNames: Set<String> = []
         var tunnels = tunnelProviders.map { TunnelContainer(tunnel: $0, index: 0) }
         tunnels.sort { $0.name < $1.name }
         var currentTunnel: TunnelContainer? = nil
         for i in 0 ..< tunnels.count {
             let tunnel = tunnels[i]
             tunnel.index = i
+            tunnelNames.insert(tunnel.name)
             if (tunnel.status != .inactive) {
                 currentTunnel = tunnel
             }
         }
         self.tunnels = tunnels
+        self.tunnelNames = tunnelNames
         if let currentTunnel = currentTunnel {
             setCurrentTunnel(tunnel: currentTunnel)
         }
@@ -60,6 +64,10 @@ class TunnelsManager {
         }
     }
 
+    func containsTunnel(named name: String) -> Bool {
+        return tunnelNames.contains(name)
+    }
+
     private func insertionIndexFor(tunnelName: String) -> Int {
         // Wishlist: Use binary search instead
         for i in 0 ..< tunnels.count {
@@ -71,6 +79,7 @@ class TunnelsManager {
     func add(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (TunnelContainer?, Error?) -> Void) {
         let tunnelName = tunnelConfiguration.interface.name
         assert(!tunnelName.isEmpty)
+        assert(!containsTunnel(named: tunnelName))
 
         isAddingTunnel = true
         let tunnelProviderManager = NETunnelProviderManager()
@@ -91,6 +100,7 @@ class TunnelsManager {
                     s.tunnels[i].index = s.tunnels[i].index + 1
                 }
                 s.tunnels.insert(tunnel, at: index)
+                s.tunnelNames.insert(tunnel.name)
                 s.delegate?.tunnelAdded(at: index)
                 completionHandler(tunnel, nil)
             }
@@ -126,7 +136,10 @@ class TunnelsManager {
 
         let tunnelProviderManager = tunnel.tunnelProvider
         let isNameChanged = (tunnelName != tunnelProviderManager.localizedDescription)
+        var oldName: String? = nil
         if (isNameChanged) {
+            assert(!containsTunnel(named: tunnelName))
+            oldName = tunnel.name
             tunnel.name = tunnelName
         }
         tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration)
@@ -142,6 +155,7 @@ class TunnelsManager {
             if let s = self {
                 if (isNameChanged) {
                     s.tunnels.remove(at: tunnel.index)
+                    s.tunnelNames.remove(oldName!)
                     for i in tunnel.index ..< s.tunnels.count {
                         s.tunnels[i].index = s.tunnels[i].index - 1
                     }
@@ -151,6 +165,7 @@ class TunnelsManager {
                         s.tunnels[i].index = s.tunnels[i].index + 1
                     }
                     s.tunnels.insert(tunnel, at: index)
+                    s.tunnelNames.insert(tunnel.name)
                     s.delegate?.tunnelsChanged()
                 } else {
                     s.delegate?.tunnelModified(at: tunnel.index)
@@ -163,6 +178,7 @@ class TunnelsManager {
     func remove(tunnel: TunnelContainer, completionHandler: @escaping (Error?) -> Void) {
         let tunnelProviderManager = tunnel.tunnelProvider
         let tunnelIndex = tunnel.index
+        let tunnelName = tunnel.name
 
         isDeletingTunnel = true
 
@@ -177,6 +193,7 @@ class TunnelsManager {
                     s.tunnels[i].index = s.tunnels[i].index + 1
                 }
                 s.tunnels.remove(at: tunnelIndex)
+                s.tunnelNames.remove(tunnelName)
                 s.delegate?.tunnelRemoved(at: tunnelIndex)
             }
             completionHandler(nil)