]> git.ipfire.org Git - thirdparty/wireguard-apple.git/commitdiff
Kit: rework keys implementation
authorAndrej Mihajlov <and@mullvad.net>
Thu, 5 May 2022 09:03:19 +0000 (11:03 +0200)
committerAndrej Mihajlov <and@mullvad.net>
Fri, 6 May 2022 11:31:28 +0000 (13:31 +0200)
Signed-off-by: Andrej Mihajlov <and@mullvad.net>
Sources/WireGuardKit/PrivateKey.swift

index f98f41d0ace9ea026f7640d20840746ff4ce8f15..aa63e851aec54f2fc03892c8a113e745d4e168ff 100644 (file)
@@ -7,8 +7,32 @@ import Foundation
 import WireGuardKitC
 #endif
 
-/// The class describing a private key used by WireGuard.
-public class PrivateKey: BaseKey {
+/// Umbrella protocol for all kinds of keys.
+public protocol WireGuardKey: RawRepresentable, Hashable where RawValue == Data {}
+
+/// Class describing a private key used by WireGuard.
+public final class PrivateKey: WireGuardKey {
+    public let rawValue: Data
+
+    /// Initialize the key with existing raw representation
+    public init?(rawValue: Data) {
+        if rawValue.count == WG_KEY_LEN {
+            self.rawValue = rawValue
+        } else {
+            return nil
+        }
+    }
+
+    /// Initialize new private key
+    convenience public init() {
+        var privateKeyData = Data(repeating: 0, count: Int(WG_KEY_LEN))
+        privateKeyData.withUnsafeMutableBytes { (rawBufferPointer: UnsafeMutableRawBufferPointer) in
+            let privateKeyBytes = rawBufferPointer.baseAddress!.assumingMemoryBound(to: UInt8.self)
+            curve25519_generate_private_key(privateKeyBytes)
+        }
+        self.init(rawValue: privateKeyData)!
+    }
+
     /// Derived public key
     public var publicKey: PublicKey {
         return rawValue.withUnsafeBytes { (privateKeyBufferPointer: UnsafeRawBufferPointer) -> PublicKey in
@@ -23,29 +47,38 @@ public class PrivateKey: BaseKey {
             return PublicKey(rawValue: publicKeyData)!
         }
     }
+}
 
-    /// Initialize new private key
-    convenience public init() {
-        var privateKeyData = Data(repeating: 0, count: Int(WG_KEY_LEN))
-        privateKeyData.withUnsafeMutableBytes { (rawBufferPointer: UnsafeMutableRawBufferPointer) in
-            let privateKeyBytes = rawBufferPointer.baseAddress!.assumingMemoryBound(to: UInt8.self)
-            curve25519_generate_private_key(privateKeyBytes)
+/// Class describing a public key used by WireGuard.
+public final class PublicKey: WireGuardKey {
+    public let rawValue: Data
+
+    /// Initialize the key with existing raw representation
+    public init?(rawValue: Data) {
+        if rawValue.count == WG_KEY_LEN {
+            self.rawValue = rawValue
+        } else {
+            return nil
         }
-        self.init(rawValue: privateKeyData)!
     }
 }
 
-/// The class describing a public key used by WireGuard.
-public class PublicKey: BaseKey {}
-
-/// The class describing a pre-shared key used by WireGuard.
-public class PreSharedKey: BaseKey {}
-
-/// The base key implementation. Should not be used directly.
-public class BaseKey: RawRepresentable, Equatable, Hashable {
-    /// Raw key representation
+/// Class describing a pre-shared key used by WireGuard.
+public final class PreSharedKey: WireGuardKey {
     public let rawValue: Data
 
+    /// Initialize the key with existing raw representation
+    public init?(rawValue: Data) {
+        if rawValue.count == WG_KEY_LEN {
+            self.rawValue = rawValue
+        } else {
+            return nil
+        }
+    }
+}
+
+// Default implementation
+extension WireGuardKey {
     /// Hex encoded representation
     public var hexKey: String {
         return rawValue.withUnsafeBytes { (rawBufferPointer: UnsafeRawBufferPointer) -> String in
@@ -66,17 +99,8 @@ public class BaseKey: RawRepresentable, Equatable, Hashable {
         }
     }
 
-    /// Initialize the key with existing raw representation
-    required public init?(rawValue: Data) {
-        if rawValue.count == WG_KEY_LEN {
-            self.rawValue = rawValue
-        } else {
-            return nil
-        }
-    }
-
     /// Initialize the key with hex representation
-    public convenience init?(hexKey: String) {
+    public init?(hexKey: String) {
         var bytes = Data(repeating: 0, count: Int(WG_KEY_LEN))
         let success = bytes.withUnsafeMutableBytes { (bufferPointer: UnsafeMutableRawBufferPointer) -> Bool in
             return key_from_hex(bufferPointer.baseAddress!.assumingMemoryBound(to: UInt8.self), hexKey)
@@ -89,7 +113,7 @@ public class BaseKey: RawRepresentable, Equatable, Hashable {
     }
 
     /// Initialize the key with base64 representation
-    public convenience init?(base64Key: String) {
+    public init?(base64Key: String) {
         var bytes = Data(repeating: 0, count: Int(WG_KEY_LEN))
         let success = bytes.withUnsafeMutableBytes { (bufferPointer: UnsafeMutableRawBufferPointer) -> Bool in
             return key_from_base64(bufferPointer.baseAddress!.assumingMemoryBound(to: UInt8.self), base64Key)
@@ -101,7 +125,9 @@ public class BaseKey: RawRepresentable, Equatable, Hashable {
         }
     }
 
-    public static func == (lhs: BaseKey, rhs: BaseKey) -> Bool {
+    // MARK: - Equatable
+
+    public static func == (lhs: Self, rhs: Self) -> Bool {
         return lhs.rawValue.withUnsafeBytes { (lhsBytes: UnsafeRawBufferPointer) -> Bool in
             return rhs.rawValue.withUnsafeBytes { (rhsBytes: UnsafeRawBufferPointer) -> Bool in
                 return key_eq(