]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: use wgcfg key types
authorDavid Crawshaw <crawshaw@tailscale.com>
Sun, 23 Feb 2020 22:18:00 +0000 (17:18 -0500)
committerDavid Crawshaw <david@zentus.com>
Thu, 2 Apr 2020 04:53:10 +0000 (15:53 +1100)
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
device/cookie.go
device/cookie_test.go
device/device.go
device/device_test.go
device/noise-helpers.go
device/noise-protocol.go
device/noise-types.go [deleted file]
device/noise_test.go
device/peer.go
device/uapi.go

index f1341281c38427e33064e48cb6f19d08bb44701f..ec54f6189ef6b1f80e4ee5c37eefe08911f986cd 100644 (file)
@@ -13,6 +13,7 @@ import (
 
        "golang.org/x/crypto/blake2s"
        "golang.org/x/crypto/chacha20poly1305"
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type CookieChecker struct {
@@ -41,7 +42,7 @@ type CookieGenerator struct {
        }
 }
 
-func (st *CookieChecker) Init(pk NoisePublicKey) {
+func (st *CookieChecker) Init(pk wgcfg.Key) {
        st.Lock()
        defer st.Unlock()
 
@@ -171,7 +172,7 @@ func (st *CookieChecker) CreateReply(
        return reply, nil
 }
 
-func (st *CookieGenerator) Init(pk NoisePublicKey) {
+func (st *CookieGenerator) Init(pk wgcfg.Key) {
        st.Lock()
        defer st.Unlock()
 
index 79a6a86c58a05109d7c0a194fc6882340cefaed6..ef01d46bbd8ca6cfe9d9ad728e84fb1d7217157c 100644 (file)
@@ -7,6 +7,8 @@ package device
 
 import (
        "testing"
+
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 func TestCookieMAC1(t *testing.T) {
@@ -18,11 +20,11 @@ func TestCookieMAC1(t *testing.T) {
                checker   CookieChecker
        )
 
-       sk, err := newPrivateKey()
+       sk, err := wgcfg.NewPrivateKey()
        if err != nil {
                t.Fatal(err)
        }
-       pk := sk.publicKey()
+       pk := sk.Public()
 
        generator.Init(pk)
        checker.Init(pk)
index a9fedea86b3481bf9f581e6f850b467536de1efc..081d59fdc360c9ea211465bd61f5a9d4638846e7 100644 (file)
@@ -17,6 +17,7 @@ import (
        "golang.zx2c4.com/wireguard/ratelimiter"
        "golang.zx2c4.com/wireguard/rwcancel"
        "golang.zx2c4.com/wireguard/tun"
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type Device struct {
@@ -46,13 +47,13 @@ type Device struct {
 
        staticIdentity struct {
                sync.RWMutex
-               privateKey NoisePrivateKey
-               publicKey  NoisePublicKey
+               privateKey wgcfg.PrivateKey
+               publicKey  wgcfg.Key
        }
 
        peers struct {
                sync.RWMutex
-               keyMap map[NoisePublicKey]*Peer
+               keyMap map[wgcfg.Key]*Peer
        }
 
        // unprotected / "self-synchronising resources"
@@ -96,7 +97,7 @@ type Device struct {
  *
  * Must hold device.peers.Mutex
  */
-func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) {
 
        // stop routing and processing of packets
 
@@ -200,13 +201,13 @@ func (device *Device) IsUnderLoad() bool {
        return until.After(now)
 }
 
-func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
+func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
        // lock required resources
 
        device.staticIdentity.Lock()
        defer device.staticIdentity.Unlock()
 
-       if sk.Equals(device.staticIdentity.privateKey) {
+       if sk.Equal(device.staticIdentity.privateKey) {
                return nil
        }
 
@@ -221,9 +222,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 
        // remove peers with matching public keys
 
-       publicKey := sk.publicKey()
+       publicKey := sk.Public()
        for key, peer := range device.peers.keyMap {
-               if peer.handshake.remoteStatic.Equals(publicKey) {
+               if peer.handshake.remoteStatic.Equal(publicKey) {
                        unsafeRemovePeer(device, peer, key)
                }
        }
@@ -239,7 +240,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
        for _, peer := range device.peers.keyMap {
                handshake := &peer.handshake
-               handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+               handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic)
                expiredPeers = append(expiredPeers, peer)
        }
 
@@ -269,7 +270,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        }
        device.tun.mtu = int32(mtu)
 
-       device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+       device.peers.keyMap = make(map[wgcfg.Key]*Peer)
 
        device.rate.limiter.Init()
        device.rate.underLoadUntil.Store(time.Time{})
@@ -317,14 +318,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        return device
 }
 
-func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
+func (device *Device) LookupPeer(pk wgcfg.Key) *Peer {
        device.peers.RLock()
        defer device.peers.RUnlock()
 
        return device.peers.keyMap[pk]
 }
 
-func (device *Device) RemovePeer(key NoisePublicKey) {
+func (device *Device) RemovePeer(key wgcfg.Key) {
        device.peers.Lock()
        defer device.peers.Unlock()
        // stop peer and remove from routing
@@ -343,7 +344,7 @@ func (device *Device) RemoveAllPeers() {
                unsafeRemovePeer(device, peer, key)
        }
 
-       device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+       device.peers.keyMap = make(map[wgcfg.Key]*Peer)
 }
 
 func (device *Device) FlushPacketQueues() {
index 87ecfc8735afef5709a02d088294be269f4a027a..925d2b114a357ee9c19b4f62a1b64fa2c89660d2 100644 (file)
@@ -14,6 +14,7 @@ import (
        "time"
 
        "golang.zx2c4.com/wireguard/tun/tuntest"
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 func TestTwoDevicePing(t *testing.T) {
@@ -90,7 +91,7 @@ func assertEqual(t *testing.T, a, b []byte) {
 }
 
 func randDevice(t *testing.T) *Device {
-       sk, err := newPrivateKey()
+       sk, err := wgcfg.NewPrivateKey()
        if err != nil {
                t.Fatal(err)
        }
index f5e4b4b7074565168d30a245911cf513092cd0c9..ae52a7d1d60527f7e8e3c992364915bada507609 100644 (file)
@@ -7,12 +7,10 @@ package device
 
 import (
        "crypto/hmac"
-       "crypto/rand"
        "crypto/subtle"
        "hash"
 
        "golang.org/x/crypto/blake2s"
-       "golang.org/x/crypto/curve25519"
 )
 
 /* KDF related functions.
@@ -75,28 +73,3 @@ func setZero(arr []byte) {
                arr[i] = 0
        }
 }
-
-func (sk *NoisePrivateKey) clamp() {
-       sk[0] &= 248
-       sk[31] = (sk[31] & 127) | 64
-}
-
-func newPrivateKey() (sk NoisePrivateKey, err error) {
-       _, err = rand.Read(sk[:])
-       sk.clamp()
-       return
-}
-
-func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
-       apk := (*[NoisePublicKeySize]byte)(&pk)
-       ask := (*[NoisePrivateKeySize]byte)(sk)
-       curve25519.ScalarBaseMult(apk, ask)
-       return
-}
-
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
-       apk := (*[NoisePublicKeySize]byte)(&pk)
-       ask := (*[NoisePrivateKeySize]byte)(sk)
-       curve25519.ScalarMult(&ss, ask, apk)
-       return ss
-}
index 6dcc8313242674ef200a14d7abc11fb3d57d481b..5d9632c83b520c7c58fa91985cac177fdcbb1d75 100644 (file)
@@ -15,6 +15,7 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/crypto/poly1305"
        "golang.zx2c4.com/wireguard/tai64n"
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type handshakeState int
@@ -84,8 +85,8 @@ const (
 type MessageInitiation struct {
        Type      uint32
        Sender    uint32
-       Ephemeral NoisePublicKey
-       Static    [NoisePublicKeySize + poly1305.TagSize]byte
+       Ephemeral wgcfg.Key
+       Static    [wgcfg.KeySize + poly1305.TagSize]byte
        Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
        MAC1      [blake2s.Size128]byte
        MAC2      [blake2s.Size128]byte
@@ -95,7 +96,7 @@ type MessageResponse struct {
        Type      uint32
        Sender    uint32
        Receiver  uint32
-       Ephemeral NoisePublicKey
+       Ephemeral wgcfg.Key
        Empty     [poly1305.TagSize]byte
        MAC1      [blake2s.Size128]byte
        MAC2      [blake2s.Size128]byte
@@ -118,15 +119,15 @@ type MessageCookieReply struct {
 type Handshake struct {
        state                     handshakeState
        mutex                     sync.RWMutex
-       hash                      [blake2s.Size]byte       // hash value
-       chainKey                  [blake2s.Size]byte       // chain key
-       presharedKey              NoiseSymmetricKey        // psk
-       localEphemeral            NoisePrivateKey          // ephemeral secret key
-       localIndex                uint32                   // used to clear hash-table
-       remoteIndex               uint32                   // index for sending
-       remoteStatic              NoisePublicKey           // long term key
-       remoteEphemeral           NoisePublicKey           // ephemeral public key
-       precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
+       hash                      [blake2s.Size]byte  // hash value
+       chainKey                  [blake2s.Size]byte  // chain key
+       presharedKey              wgcfg.SymmetricKey  // psk
+       localEphemeral            wgcfg.PrivateKey    // ephemeral secret key
+       localIndex                uint32              // used to clear hash-table
+       remoteIndex               uint32              // index for sending
+       remoteStatic              wgcfg.Key           // long term key
+       remoteEphemeral           wgcfg.Key           // ephemeral public key
+       precomputedStaticStatic   [wgcfg.KeySize]byte // precomputed shared secret
        lastTimestamp             tai64n.Timestamp
        lastInitiationConsumption time.Time
        lastSentHandshake         time.Time
@@ -188,7 +189,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        var err error
        handshake.hash = InitialHash
        handshake.chainKey = InitialChainKey
-       handshake.localEphemeral, err = newPrivateKey()
+       handshake.localEphemeral, err = wgcfg.NewPrivateKey()
        if err != nil {
                return nil, err
        }
@@ -197,14 +198,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 
        msg := MessageInitiation{
                Type:      MessageInitiationType,
-               Ephemeral: handshake.localEphemeral.publicKey(),
+               Ephemeral: handshake.localEphemeral.Public(),
        }
 
        handshake.mixKey(msg.Ephemeral[:])
        handshake.mixHash(msg.Ephemeral[:])
 
        // encrypt static key
-       ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+       ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic)
        if isZero(ss[:]) {
                return nil, errZeroECDHResult
        }
@@ -265,9 +266,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
 
        // decrypt static key
        var err error
-       var peerPK NoisePublicKey
+       var peerPK wgcfg.Key
        var key [chacha20poly1305.KeySize]byte
-       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+       ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral)
        if isZero(ss[:]) {
                return nil
        }
@@ -372,18 +373,18 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
        // create ephemeral key
 
-       handshake.localEphemeral, err = newPrivateKey()
+       handshake.localEphemeral, err = wgcfg.NewPrivateKey()
        if err != nil {
                return nil, err
        }
-       msg.Ephemeral = handshake.localEphemeral.publicKey()
+       msg.Ephemeral = handshake.localEphemeral.Public()
        handshake.mixHash(msg.Ephemeral[:])
        handshake.mixKey(msg.Ephemeral[:])
 
        func() {
-               ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+               ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral)
                handshake.mixKey(ss[:])
-               ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+               ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic)
                handshake.mixKey(ss[:])
        }()
 
@@ -453,13 +454,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
 
                func() {
-                       ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+                       ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral)
                        mixKey(&chainKey, &chainKey, ss[:])
                        setZero(ss[:])
                }()
 
                func() {
-                       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+                       ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral)
                        mixKey(&chainKey, &chainKey, ss[:])
                        setZero(ss[:])
                }()
diff --git a/device/noise-types.go b/device/noise-types.go
deleted file mode 100644 (file)
index a1976ff..0000000
+++ /dev/null
@@ -1,91 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
-       "crypto/subtle"
-       "encoding/hex"
-       "errors"
-
-       "golang.org/x/crypto/chacha20poly1305"
-)
-
-const (
-       NoisePublicKeySize  = 32
-       NoisePrivateKeySize = 32
-)
-
-type (
-       NoisePublicKey    [NoisePublicKeySize]byte
-       NoisePrivateKey   [NoisePrivateKeySize]byte
-       NoiseSymmetricKey [chacha20poly1305.KeySize]byte
-       NoiseNonce        uint64 // padded to 12-bytes
-)
-
-func loadExactHex(dst []byte, src string) error {
-       slice, err := hex.DecodeString(src)
-       if err != nil {
-               return err
-       }
-       if len(slice) != len(dst) {
-               return errors.New("hex string does not fit the slice")
-       }
-       copy(dst, slice)
-       return nil
-}
-
-func (key NoisePrivateKey) IsZero() bool {
-       var zero NoisePrivateKey
-       return key.Equals(zero)
-}
-
-func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
-       return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
-}
-
-func (key *NoisePrivateKey) FromHex(src string) (err error) {
-       err = loadExactHex(key[:], src)
-       key.clamp()
-       return
-}
-
-func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
-       err = loadExactHex(key[:], src)
-       if key.IsZero() {
-               return
-       }
-       key.clamp()
-       return
-}
-
-func (key NoisePrivateKey) ToHex() string {
-       return hex.EncodeToString(key[:])
-}
-
-func (key *NoisePublicKey) FromHex(src string) error {
-       return loadExactHex(key[:], src)
-}
-
-func (key NoisePublicKey) ToHex() string {
-       return hex.EncodeToString(key[:])
-}
-
-func (key NoisePublicKey) IsZero() bool {
-       var zero NoisePublicKey
-       return key.Equals(zero)
-}
-
-func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
-       return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
-}
-
-func (key *NoiseSymmetricKey) FromHex(src string) error {
-       return loadExactHex(key[:], src)
-}
-
-func (key NoiseSymmetricKey) ToHex() string {
-       return hex.EncodeToString(key[:])
-}
index 6ba3f2e6246e39a1293c69e5cf598fd35ad7a359..e431588a68de417ed2915a5e086a27dc520ba7ae 100644 (file)
@@ -11,24 +11,6 @@ import (
        "testing"
 )
 
-func TestCurveWrappers(t *testing.T) {
-       sk1, err := newPrivateKey()
-       assertNil(t, err)
-
-       sk2, err := newPrivateKey()
-       assertNil(t, err)
-
-       pk1 := sk1.publicKey()
-       pk2 := sk2.publicKey()
-
-       ss1 := sk1.sharedSecret(pk2)
-       ss2 := sk2.sharedSecret(pk1)
-
-       if ss1 != ss2 {
-               t.Fatal("Failed to compute shared secet")
-       }
-}
-
 func TestNoiseHandshake(t *testing.T) {
        dev1 := randDevice(t)
        dev2 := randDevice(t)
@@ -36,8 +18,14 @@ func TestNoiseHandshake(t *testing.T) {
        defer dev1.Close()
        defer dev2.Close()
 
-       peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
-       peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+       peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public())
+       if err != nil {
+               t.Fatal(err)
+       }
+       peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public())
+       if err != nil {
+               t.Fatal(err)
+       }
 
        assertEqual(
                t,
index a96f2612a1a5544b19224a24626f06089ad8c6f4..3ec625f8a79290ced139c5d93208a85b05dc240f 100644 (file)
@@ -14,6 +14,7 @@ import (
        "time"
 
        "golang.zx2c4.com/wireguard/conn"
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 const (
@@ -76,7 +77,8 @@ type Peer struct {
        cookieGenerator CookieGenerator
 }
 
-func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) {
+
        if device.isClosed.Get() {
                return nil, errors.New("device closed")
        }
@@ -116,7 +118,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        handshake := &peer.handshake
        handshake.mutex.Lock()
-       handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+       handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk)
        handshake.remoteStatic = pk
        handshake.mutex.Unlock()
 
index 1671faa30d61fbd8e35d236f21425a59660f05bf..b266f4cce6b09deb38c313cea7a1bb01190ed342 100644 (file)
@@ -18,6 +18,7 @@ import (
 
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/ipc"
+       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type IPCError struct {
@@ -54,7 +55,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
                // serialize device related values
 
                if !device.staticIdentity.privateKey.IsZero() {
-                       send("private_key=" + device.staticIdentity.privateKey.ToHex())
+                       send("private_key=" + device.staticIdentity.privateKey.HexString())
                }
 
                if device.net.port != 0 {
@@ -71,8 +72,8 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
                        peer.RLock()
                        defer peer.RUnlock()
 
-                       send("public_key=" + peer.handshake.remoteStatic.ToHex())
-                       send("preshared_key=" + peer.handshake.presharedKey.ToHex())
+                       send("public_key=" + peer.handshake.remoteStatic.HexString())
+                       send("preshared_key=" + peer.handshake.presharedKey.HexString())
                        send("protocol_version=1")
                        if peer.endpoint != nil {
                                send("endpoint=" + peer.endpoint.DstToString())
@@ -139,8 +140,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
 
                        switch key {
                        case "private_key":
-                               var sk NoisePrivateKey
-                               err := sk.FromMaybeZeroHex(value)
+                               sk, err := wgcfg.ParsePrivateHexKey(value)
                                if err != nil {
                                        logError.Println("Failed to set private_key:", err)
                                        return &IPCError{ipc.IpcErrorInvalid}
@@ -221,8 +221,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
                        switch key {
 
                        case "public_key":
-                               var publicKey NoisePublicKey
-                               err := publicKey.FromHex(value)
+                               publicKey, err := wgcfg.ParseHexKey(value)
                                if err != nil {
                                        logError.Println("Failed to get peer by public key:", err)
                                        return &IPCError{ipc.IpcErrorInvalid}
@@ -231,7 +230,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
                                // ignore peer with public key of device
 
                                device.staticIdentity.RLock()
-                               dummy = device.staticIdentity.publicKey.Equals(publicKey)
+                               dummy = device.staticIdentity.publicKey.Equal(publicKey)
                                device.staticIdentity.RUnlock()
 
                                if dummy {
@@ -291,7 +290,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
                                logDebug.Println(peer, "- UAPI: Updating preshared key")
 
                                peer.handshake.mutex.Lock()
-                               err := peer.handshake.presharedKey.FromHex(value)
+                               key, err := wgcfg.ParseSymmetricHexKey(value)
+                               peer.handshake.presharedKey = key
                                peer.handshake.mutex.Unlock()
 
                                if err != nil {