]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Revert "device: use wgcfg key types"
authorDavid Crawshaw <crawshaw@tailscale.com>
Tue, 7 Apr 2020 05:52:04 +0000 (15:52 +1000)
committerDavid Crawshaw <crawshaw@tailscale.com>
Tue, 7 Apr 2020 05:52:41 +0000 (15:52 +1000)
More cleanup work of wgcfg to do before bringing this in.

This reverts commit 83ca9b47b63b4d07630c4d579faf1111e42537d3.

device/cookie.go
device/cookie_test.go
device/device.go
device/device_test.go
device/noise-helpers.go
device/noise-protocol.go
device/noise-types.go [new file with mode: 0644]
device/noise_test.go
device/peer.go
device/uapi.go

index ec54f6189ef6b1f80e4ee5c37eefe08911f986cd..f1341281c38427e33064e48cb6f19d08bb44701f 100644 (file)
@@ -13,7 +13,6 @@ import (
 
        "golang.org/x/crypto/blake2s"
        "golang.org/x/crypto/chacha20poly1305"
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type CookieChecker struct {
@@ -42,7 +41,7 @@ type CookieGenerator struct {
        }
 }
 
-func (st *CookieChecker) Init(pk wgcfg.Key) {
+func (st *CookieChecker) Init(pk NoisePublicKey) {
        st.Lock()
        defer st.Unlock()
 
@@ -172,7 +171,7 @@ func (st *CookieChecker) CreateReply(
        return reply, nil
 }
 
-func (st *CookieGenerator) Init(pk wgcfg.Key) {
+func (st *CookieGenerator) Init(pk NoisePublicKey) {
        st.Lock()
        defer st.Unlock()
 
index ef01d46bbd8ca6cfe9d9ad728e84fb1d7217157c..79a6a86c58a05109d7c0a194fc6882340cefaed6 100644 (file)
@@ -7,8 +7,6 @@ package device
 
 import (
        "testing"
-
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 func TestCookieMAC1(t *testing.T) {
@@ -20,11 +18,11 @@ func TestCookieMAC1(t *testing.T) {
                checker   CookieChecker
        )
 
-       sk, err := wgcfg.NewPrivateKey()
+       sk, err := newPrivateKey()
        if err != nil {
                t.Fatal(err)
        }
-       pk := sk.Public()
+       pk := sk.publicKey()
 
        generator.Init(pk)
        checker.Init(pk)
index 081d59fdc360c9ea211465bd61f5a9d4638846e7..a9fedea86b3481bf9f581e6f850b467536de1efc 100644 (file)
@@ -17,7 +17,6 @@ import (
        "golang.zx2c4.com/wireguard/ratelimiter"
        "golang.zx2c4.com/wireguard/rwcancel"
        "golang.zx2c4.com/wireguard/tun"
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type Device struct {
@@ -47,13 +46,13 @@ type Device struct {
 
        staticIdentity struct {
                sync.RWMutex
-               privateKey wgcfg.PrivateKey
-               publicKey  wgcfg.Key
+               privateKey NoisePrivateKey
+               publicKey  NoisePublicKey
        }
 
        peers struct {
                sync.RWMutex
-               keyMap map[wgcfg.Key]*Peer
+               keyMap map[NoisePublicKey]*Peer
        }
 
        // unprotected / "self-synchronising resources"
@@ -97,7 +96,7 @@ type Device struct {
  *
  * Must hold device.peers.Mutex
  */
-func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) {
+func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
 
        // stop routing and processing of packets
 
@@ -201,13 +200,13 @@ func (device *Device) IsUnderLoad() bool {
        return until.After(now)
 }
 
-func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        // lock required resources
 
        device.staticIdentity.Lock()
        defer device.staticIdentity.Unlock()
 
-       if sk.Equal(device.staticIdentity.privateKey) {
+       if sk.Equals(device.staticIdentity.privateKey) {
                return nil
        }
 
@@ -222,9 +221,9 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
 
        // remove peers with matching public keys
 
-       publicKey := sk.Public()
+       publicKey := sk.publicKey()
        for key, peer := range device.peers.keyMap {
-               if peer.handshake.remoteStatic.Equal(publicKey) {
+               if peer.handshake.remoteStatic.Equals(publicKey) {
                        unsafeRemovePeer(device, peer, key)
                }
        }
@@ -240,7 +239,7 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
        expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
        for _, peer := range device.peers.keyMap {
                handshake := &peer.handshake
-               handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic)
+               handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
                expiredPeers = append(expiredPeers, peer)
        }
 
@@ -270,7 +269,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        }
        device.tun.mtu = int32(mtu)
 
-       device.peers.keyMap = make(map[wgcfg.Key]*Peer)
+       device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 
        device.rate.limiter.Init()
        device.rate.underLoadUntil.Store(time.Time{})
@@ -318,14 +317,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        return device
 }
 
-func (device *Device) LookupPeer(pk wgcfg.Key) *Peer {
+func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
        device.peers.RLock()
        defer device.peers.RUnlock()
 
        return device.peers.keyMap[pk]
 }
 
-func (device *Device) RemovePeer(key wgcfg.Key) {
+func (device *Device) RemovePeer(key NoisePublicKey) {
        device.peers.Lock()
        defer device.peers.Unlock()
        // stop peer and remove from routing
@@ -344,7 +343,7 @@ func (device *Device) RemoveAllPeers() {
                unsafeRemovePeer(device, peer, key)
        }
 
-       device.peers.keyMap = make(map[wgcfg.Key]*Peer)
+       device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 }
 
 func (device *Device) FlushPacketQueues() {
index 925d2b114a357ee9c19b4f62a1b64fa2c89660d2..87ecfc8735afef5709a02d088294be269f4a027a 100644 (file)
@@ -14,7 +14,6 @@ import (
        "time"
 
        "golang.zx2c4.com/wireguard/tun/tuntest"
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 func TestTwoDevicePing(t *testing.T) {
@@ -91,7 +90,7 @@ func assertEqual(t *testing.T, a, b []byte) {
 }
 
 func randDevice(t *testing.T) *Device {
-       sk, err := wgcfg.NewPrivateKey()
+       sk, err := newPrivateKey()
        if err != nil {
                t.Fatal(err)
        }
index ae52a7d1d60527f7e8e3c992364915bada507609..f5e4b4b7074565168d30a245911cf513092cd0c9 100644 (file)
@@ -7,10 +7,12 @@ package device
 
 import (
        "crypto/hmac"
+       "crypto/rand"
        "crypto/subtle"
        "hash"
 
        "golang.org/x/crypto/blake2s"
+       "golang.org/x/crypto/curve25519"
 )
 
 /* KDF related functions.
@@ -73,3 +75,28 @@ func setZero(arr []byte) {
                arr[i] = 0
        }
 }
+
+func (sk *NoisePrivateKey) clamp() {
+       sk[0] &= 248
+       sk[31] = (sk[31] & 127) | 64
+}
+
+func newPrivateKey() (sk NoisePrivateKey, err error) {
+       _, err = rand.Read(sk[:])
+       sk.clamp()
+       return
+}
+
+func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
+       apk := (*[NoisePublicKeySize]byte)(&pk)
+       ask := (*[NoisePrivateKeySize]byte)(sk)
+       curve25519.ScalarBaseMult(apk, ask)
+       return
+}
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+       apk := (*[NoisePublicKeySize]byte)(&pk)
+       ask := (*[NoisePrivateKeySize]byte)(sk)
+       curve25519.ScalarMult(&ss, ask, apk)
+       return ss
+}
index 3ce7839eda2431f3d9bdc9036ec7fd328d8ab31d..03b872ba924c979d09ee58fe8e60562f47a7d4a0 100644 (file)
@@ -15,7 +15,6 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/crypto/poly1305"
        "golang.zx2c4.com/wireguard/tai64n"
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type handshakeState int
@@ -85,8 +84,8 @@ const (
 type MessageInitiation struct {
        Type      uint32
        Sender    uint32
-       Ephemeral wgcfg.Key
-       Static    [wgcfg.KeySize + poly1305.TagSize]byte
+       Ephemeral NoisePublicKey
+       Static    [NoisePublicKeySize + poly1305.TagSize]byte
        Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
        MAC1      [blake2s.Size128]byte
        MAC2      [blake2s.Size128]byte
@@ -96,7 +95,7 @@ type MessageResponse struct {
        Type      uint32
        Sender    uint32
        Receiver  uint32
-       Ephemeral wgcfg.Key
+       Ephemeral NoisePublicKey
        Empty     [poly1305.TagSize]byte
        MAC1      [blake2s.Size128]byte
        MAC2      [blake2s.Size128]byte
@@ -119,15 +118,15 @@ type MessageCookieReply struct {
 type Handshake struct {
        state                     handshakeState
        mutex                     sync.RWMutex
-       hash                      [blake2s.Size]byte  // hash value
-       chainKey                  [blake2s.Size]byte  // chain key
-       presharedKey              wgcfg.SymmetricKey  // psk
-       localEphemeral            wgcfg.PrivateKey    // ephemeral secret key
-       localIndex                uint32              // used to clear hash-table
-       remoteIndex               uint32              // index for sending
-       remoteStatic              wgcfg.Key           // long term key
-       remoteEphemeral           wgcfg.Key           // ephemeral public key
-       precomputedStaticStatic   [wgcfg.KeySize]byte // precomputed shared secret
+       hash                      [blake2s.Size]byte       // hash value
+       chainKey                  [blake2s.Size]byte       // chain key
+       presharedKey              NoiseSymmetricKey        // psk
+       localEphemeral            NoisePrivateKey          // ephemeral secret key
+       localIndex                uint32                   // used to clear hash-table
+       remoteIndex               uint32                   // index for sending
+       remoteStatic              NoisePublicKey           // long term key
+       remoteEphemeral           NoisePublicKey           // ephemeral public key
+       precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
        lastTimestamp             tai64n.Timestamp
        lastInitiationConsumption time.Time
        lastSentHandshake         time.Time
@@ -189,7 +188,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        var err error
        handshake.hash = InitialHash
        handshake.chainKey = InitialChainKey
-       handshake.localEphemeral, err = wgcfg.NewPrivateKey()
+       handshake.localEphemeral, err = newPrivateKey()
        if err != nil {
                return nil, err
        }
@@ -198,14 +197,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 
        msg := MessageInitiation{
                Type:      MessageInitiationType,
-               Ephemeral: handshake.localEphemeral.Public(),
+               Ephemeral: handshake.localEphemeral.publicKey(),
        }
 
        handshake.mixKey(msg.Ephemeral[:])
        handshake.mixHash(msg.Ephemeral[:])
 
        // encrypt static key
-       ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic)
+       ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
        if isZero(ss[:]) {
                return nil, errZeroECDHResult
        }
@@ -266,9 +265,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
 
        // decrypt static key
        var err error
-       var peerPK wgcfg.Key
+       var peerPK NoisePublicKey
        var key [chacha20poly1305.KeySize]byte
-       ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral)
+       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
        if isZero(ss[:]) {
                return nil
        }
@@ -377,18 +376,18 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
        // create ephemeral key
 
-       handshake.localEphemeral, err = wgcfg.NewPrivateKey()
+       handshake.localEphemeral, err = newPrivateKey()
        if err != nil {
                return nil, err
        }
-       msg.Ephemeral = handshake.localEphemeral.Public()
+       msg.Ephemeral = handshake.localEphemeral.publicKey()
        handshake.mixHash(msg.Ephemeral[:])
        handshake.mixKey(msg.Ephemeral[:])
 
        func() {
-               ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral)
+               ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
                handshake.mixKey(ss[:])
-               ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic)
+               ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
                handshake.mixKey(ss[:])
        }()
 
@@ -458,13 +457,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
 
                func() {
-                       ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral)
+                       ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
                        mixKey(&chainKey, &chainKey, ss[:])
                        setZero(ss[:])
                }()
 
                func() {
-                       ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral)
+                       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
                        mixKey(&chainKey, &chainKey, ss[:])
                        setZero(ss[:])
                }()
diff --git a/device/noise-types.go b/device/noise-types.go
new file mode 100644 (file)
index 0000000..a1976ff
--- /dev/null
@@ -0,0 +1,91 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+       "crypto/subtle"
+       "encoding/hex"
+       "errors"
+
+       "golang.org/x/crypto/chacha20poly1305"
+)
+
+const (
+       NoisePublicKeySize  = 32
+       NoisePrivateKeySize = 32
+)
+
+type (
+       NoisePublicKey    [NoisePublicKeySize]byte
+       NoisePrivateKey   [NoisePrivateKeySize]byte
+       NoiseSymmetricKey [chacha20poly1305.KeySize]byte
+       NoiseNonce        uint64 // padded to 12-bytes
+)
+
+func loadExactHex(dst []byte, src string) error {
+       slice, err := hex.DecodeString(src)
+       if err != nil {
+               return err
+       }
+       if len(slice) != len(dst) {
+               return errors.New("hex string does not fit the slice")
+       }
+       copy(dst, slice)
+       return nil
+}
+
+func (key NoisePrivateKey) IsZero() bool {
+       var zero NoisePrivateKey
+       return key.Equals(zero)
+}
+
+func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
+       return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
+}
+
+func (key *NoisePrivateKey) FromHex(src string) (err error) {
+       err = loadExactHex(key[:], src)
+       key.clamp()
+       return
+}
+
+func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
+       err = loadExactHex(key[:], src)
+       if key.IsZero() {
+               return
+       }
+       key.clamp()
+       return
+}
+
+func (key NoisePrivateKey) ToHex() string {
+       return hex.EncodeToString(key[:])
+}
+
+func (key *NoisePublicKey) FromHex(src string) error {
+       return loadExactHex(key[:], src)
+}
+
+func (key NoisePublicKey) ToHex() string {
+       return hex.EncodeToString(key[:])
+}
+
+func (key NoisePublicKey) IsZero() bool {
+       var zero NoisePublicKey
+       return key.Equals(zero)
+}
+
+func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
+       return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
+}
+
+func (key *NoiseSymmetricKey) FromHex(src string) error {
+       return loadExactHex(key[:], src)
+}
+
+func (key NoiseSymmetricKey) ToHex() string {
+       return hex.EncodeToString(key[:])
+}
index e431588a68de417ed2915a5e086a27dc520ba7ae..6ba3f2e6246e39a1293c69e5cf598fd35ad7a359 100644 (file)
@@ -11,6 +11,24 @@ import (
        "testing"
 )
 
+func TestCurveWrappers(t *testing.T) {
+       sk1, err := newPrivateKey()
+       assertNil(t, err)
+
+       sk2, err := newPrivateKey()
+       assertNil(t, err)
+
+       pk1 := sk1.publicKey()
+       pk2 := sk2.publicKey()
+
+       ss1 := sk1.sharedSecret(pk2)
+       ss2 := sk2.sharedSecret(pk1)
+
+       if ss1 != ss2 {
+               t.Fatal("Failed to compute shared secet")
+       }
+}
+
 func TestNoiseHandshake(t *testing.T) {
        dev1 := randDevice(t)
        dev2 := randDevice(t)
@@ -18,14 +36,8 @@ func TestNoiseHandshake(t *testing.T) {
        defer dev1.Close()
        defer dev2.Close()
 
-       peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public())
-       if err != nil {
-               t.Fatal(err)
-       }
-       peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public())
-       if err != nil {
-               t.Fatal(err)
-       }
+       peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
+       peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
 
        assertEqual(
                t,
index 3ec625f8a79290ced139c5d93208a85b05dc240f..a96f2612a1a5544b19224a24626f06089ad8c6f4 100644 (file)
@@ -14,7 +14,6 @@ import (
        "time"
 
        "golang.zx2c4.com/wireguard/conn"
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 const (
@@ -77,8 +76,7 @@ type Peer struct {
        cookieGenerator CookieGenerator
 }
 
-func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) {
-
+func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        if device.isClosed.Get() {
                return nil, errors.New("device closed")
        }
@@ -118,7 +116,7 @@ func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) {
 
        handshake := &peer.handshake
        handshake.mutex.Lock()
-       handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk)
+       handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
        handshake.remoteStatic = pk
        handshake.mutex.Unlock()
 
index b266f4cce6b09deb38c313cea7a1bb01190ed342..1671faa30d61fbd8e35d236f21425a59660f05bf 100644 (file)
@@ -18,7 +18,6 @@ import (
 
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/ipc"
-       "golang.zx2c4.com/wireguard/wgcfg"
 )
 
 type IPCError struct {
@@ -55,7 +54,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
                // serialize device related values
 
                if !device.staticIdentity.privateKey.IsZero() {
-                       send("private_key=" + device.staticIdentity.privateKey.HexString())
+                       send("private_key=" + device.staticIdentity.privateKey.ToHex())
                }
 
                if device.net.port != 0 {
@@ -72,8 +71,8 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
                        peer.RLock()
                        defer peer.RUnlock()
 
-                       send("public_key=" + peer.handshake.remoteStatic.HexString())
-                       send("preshared_key=" + peer.handshake.presharedKey.HexString())
+                       send("public_key=" + peer.handshake.remoteStatic.ToHex())
+                       send("preshared_key=" + peer.handshake.presharedKey.ToHex())
                        send("protocol_version=1")
                        if peer.endpoint != nil {
                                send("endpoint=" + peer.endpoint.DstToString())
@@ -140,7 +139,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
 
                        switch key {
                        case "private_key":
-                               sk, err := wgcfg.ParsePrivateHexKey(value)
+                               var sk NoisePrivateKey
+                               err := sk.FromMaybeZeroHex(value)
                                if err != nil {
                                        logError.Println("Failed to set private_key:", err)
                                        return &IPCError{ipc.IpcErrorInvalid}
@@ -221,7 +221,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
                        switch key {
 
                        case "public_key":
-                               publicKey, err := wgcfg.ParseHexKey(value)
+                               var publicKey NoisePublicKey
+                               err := publicKey.FromHex(value)
                                if err != nil {
                                        logError.Println("Failed to get peer by public key:", err)
                                        return &IPCError{ipc.IpcErrorInvalid}
@@ -230,7 +231,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
                                // ignore peer with public key of device
 
                                device.staticIdentity.RLock()
-                               dummy = device.staticIdentity.publicKey.Equal(publicKey)
+                               dummy = device.staticIdentity.publicKey.Equals(publicKey)
                                device.staticIdentity.RUnlock()
 
                                if dummy {
@@ -290,8 +291,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
                                logDebug.Println(peer, "- UAPI: Updating preshared key")
 
                                peer.handshake.mutex.Lock()
-                               key, err := wgcfg.ParseSymmetricHexKey(value)
-                               peer.handshake.presharedKey = key
+                               err := peer.handshake.presharedKey.FromHex(value)
                                peer.handshake.mutex.Unlock()
 
                                if err != nil {