]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: uniformly check ECDH output for zeros
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 16 Feb 2023 14:51:30 +0000 (15:51 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 16 Feb 2023 15:33:14 +0000 (16:33 +0100)
For some reason, this was omitted for response messages.

Reported-by: z <dzm@unexpl0.red>
Fixes: 8c34c4c ("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/device.go
device/noise-helpers.go
device/noise-protocol.go
device/noise_test.go
device/peer.go

index 8e55724327c04bfa203521d84fcd8adab7ce6a1f..3368a930b5cfc9804571e87aae5724c4e82908d3 100644 (file)
@@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
        for _, peer := range device.peers.keyMap {
                handshake := &peer.handshake
-               handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+               handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
                expiredPeers = append(expiredPeers, peer)
        }
 
index 729f8b005c54f72af5c398f9e609579b87cf9822..c2f356b96d32f2f4837454bd19b15689960a42e7 100644 (file)
@@ -9,6 +9,7 @@ import (
        "crypto/hmac"
        "crypto/rand"
        "crypto/subtle"
+       "errors"
        "hash"
 
        "golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
        return
 }
 
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+var errInvalidPublicKey = errors.New("invalid public key")
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
        apk := (*[NoisePublicKeySize]byte)(&pk)
        ask := (*[NoisePrivateKeySize]byte)(sk)
        curve25519.ScalarMult(&ss, ask, apk)
-       return ss
+       if isZero(ss[:]) {
+               return ss, errInvalidPublicKey
+       }
+       return ss, nil
 }
index 117e960a8b26a30bb17a4e7133f3d7732fad024d..e8f6145e50fa959bf601d73ac4ebab535e9fcf76 100644 (file)
@@ -175,8 +175,6 @@ func init() {
 }
 
 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
-       errZeroECDHResult := errors.New("ECDH returned all zeros")
-
        device.staticIdentity.RLock()
        defer device.staticIdentity.RUnlock()
 
@@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        handshake.mixHash(msg.Ephemeral[:])
 
        // encrypt static key
-       ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
-       if isZero(ss[:]) {
-               return nil, errZeroECDHResult
+       ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+       if err != nil {
+               return nil, err
        }
        var key [chacha20poly1305.KeySize]byte
        KDF2(
@@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 
        // encrypt timestamp
        if isZero(handshake.precomputedStaticStatic[:]) {
-               return nil, errZeroECDHResult
+               return nil, errInvalidPublicKey
        }
        KDF2(
                &handshake.chainKey,
@@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
 
        // decrypt static key
-       var err error
        var peerPK NoisePublicKey
        var key [chacha20poly1305.KeySize]byte
-       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
-       if isZero(ss[:]) {
+       ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+       if err != nil {
                return nil
        }
        KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        handshake.mixHash(msg.Ephemeral[:])
        handshake.mixKey(msg.Ephemeral[:])
 
-       func() {
-               ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
-               handshake.mixKey(ss[:])
-               ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
-               handshake.mixKey(ss[:])
-       }()
+       ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+       if err != nil {
+               return nil, err
+       }
+       handshake.mixKey(ss[:])
+       ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+       if err != nil {
+               return nil, err
+       }
+       handshake.mixKey(ss[:])
 
        // add preshared key
 
@@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
        handshake.mixHash(tau[:])
 
-       func() {
-               aead, _ := chacha20poly1305.New(key[:])
-               aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
-               handshake.mixHash(msg.Empty[:])
-       }()
+       aead, _ := chacha20poly1305.New(key[:])
+       aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+       handshake.mixHash(msg.Empty[:])
 
        handshake.state = handshakeResponseCreated
 
@@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
                mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
 
-               func() {
-                       ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
-                       mixKey(&chainKey, &chainKey, ss[:])
-                       setZero(ss[:])
-               }()
+               ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+               if err != nil {
+                       return false
+               }
+               mixKey(&chainKey, &chainKey, ss[:])
+               setZero(ss[:])
 
-               func() {
-                       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
-                       mixKey(&chainKey, &chainKey, ss[:])
-                       setZero(ss[:])
-               }()
+               ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+               if err != nil {
+                       return false
+               }
+               mixKey(&chainKey, &chainKey, ss[:])
+               setZero(ss[:])
 
                // add preshared key (psk)
 
@@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                // authenticate transcript
 
                aead, _ := chacha20poly1305.New(key[:])
-               _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+               _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
                if err != nil {
                        return false
                }
index 587d1e55d8575b09eecfbf63aab750a21023d5a2..2dd53241dd66c222a46b5cf88fbdfe10cabd6b53 100644 (file)
@@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
        pk1 := sk1.publicKey()
        pk2 := sk2.publicKey()
 
-       ss1 := sk1.sharedSecret(pk2)
-       ss2 := sk2.sharedSecret(pk1)
+       ss1, err1 := sk1.sharedSecret(pk2)
+       ss2, err2 := sk2.sharedSecret(pk1)
 
-       if ss1 != ss2 {
+       if ss1 != ss2 || err1 != nil || err2 != nil {
                t.Fatal("Failed to compute shared secet")
        }
 }
index 8266dacc015091d2f632850f2293b62caff6a07c..0e7b6698542d04f00ac57358358ca8fdbbc7e15b 100644 (file)
@@ -92,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        // pre-compute DH
        handshake := &peer.handshake
        handshake.mutex.Lock()
-       handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+       handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
        handshake.remoteStatic = pk
        handshake.mutex.Unlock()