]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
noise: unify zero checking of ecdh
authorJason A. Donenfeld <Jason@zx2c4.com>
Wed, 18 Mar 2020 05:06:56 +0000 (23:06 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Wed, 18 Mar 2020 05:07:14 +0000 (23:07 -0600)
device/device.go
device/noise-protocol.go
device/peer.go

index 0b909a77d5d85e1c44433896867cf906cd3328de..8c08f1c34f8c679326c7cb3f5a888e70d6788f0f 100644 (file)
@@ -240,9 +240,6 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        for _, peer := range device.peers.keyMap {
                handshake := &peer.handshake
                handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
-               if isZero(handshake.precomputedStaticStatic[:]) {
-                       panic("an invalid peer public key made it into the configuration")
-               }
                expiredPeers = append(expiredPeers, peer)
        }
 
index 1c08e0a5f0300271b43ed8e5db22018a8f1fe6a7..ee327d2eb551783d8669fdd1dae7180987fc173d 100644 (file)
@@ -154,6 +154,7 @@ func init() {
 }
 
 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+       var errZeroECDHResult = errors.New("ECDH returned all zeros")
 
        device.staticIdentity.RLock()
        defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
 
-       if isZero(handshake.precomputedStaticStatic[:]) {
-               return nil, errors.New("static shared secret is zero")
-       }
-
        // create ephemeral key
-
        var err error
        handshake.hash = InitialHash
        handshake.chainKey = InitialChainKey
@@ -176,56 +172,53 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
                return nil, err
        }
 
-       // assign index
-
-       device.indexTable.Delete(handshake.localIndex)
-       handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
-
-       if err != nil {
-               return nil, err
-       }
-
        handshake.mixHash(handshake.remoteStatic[:])
 
        msg := MessageInitiation{
                Type:      MessageInitiationType,
                Ephemeral: handshake.localEphemeral.publicKey(),
-               Sender:    handshake.localIndex,
        }
 
        handshake.mixKey(msg.Ephemeral[:])
        handshake.mixHash(msg.Ephemeral[:])
 
        // encrypt static key
-
-       func() {
-               var key [chacha20poly1305.KeySize]byte
-               ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
-               KDF2(
-                       &handshake.chainKey,
-                       &key,
-                       handshake.chainKey[:],
-                       ss[:],
-               )
-               aead, _ := chacha20poly1305.New(key[:])
-               aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
-       }()
+       ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+       if isZero(ss[:]) {
+               return nil, errZeroECDHResult
+       }
+       var key [chacha20poly1305.KeySize]byte
+       KDF2(
+               &handshake.chainKey,
+               &key,
+               handshake.chainKey[:],
+               ss[:],
+       )
+       aead, _ := chacha20poly1305.New(key[:])
+       aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
        handshake.mixHash(msg.Static[:])
 
        // encrypt timestamp
-
+       if isZero(handshake.precomputedStaticStatic[:]) {
+               return nil, errZeroECDHResult
+       }
+       KDF2(
+               &handshake.chainKey,
+               &key,
+               handshake.chainKey[:],
+               handshake.precomputedStaticStatic[:],
+       )
        timestamp := tai64n.Now()
-       func() {
-               var key [chacha20poly1305.KeySize]byte
-               KDF2(
-                       &handshake.chainKey,
-                       &key,
-                       handshake.chainKey[:],
-                       handshake.precomputedStaticStatic[:],
-               )
-               aead, _ := chacha20poly1305.New(key[:])
-               aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
-       }()
+       aead, _ = chacha20poly1305.New(key[:])
+       aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
+
+       // assign index
+       device.indexTable.Delete(handshake.localIndex)
+       msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+       if err != nil {
+               return nil, err
+       }
+       handshake.localIndex = msg.Sender
 
        handshake.mixHash(msg.Timestamp[:])
        handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
 
        // decrypt static key
-
        var err error
        var peerPK NoisePublicKey
-       func() {
-               var key [chacha20poly1305.KeySize]byte
-               ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
-               KDF2(&chainKey, &key, chainKey[:], ss[:])
-               aead, _ := chacha20poly1305.New(key[:])
-               _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
-       }()
+       var key [chacha20poly1305.KeySize]byte
+       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+       if isZero(ss[:]) {
+               return nil
+       }
+       KDF2(&chainKey, &key, chainKey[:], ss[:])
+       aead, _ := chacha20poly1305.New(key[:])
+       _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
        if err != nil {
                return nil
        }
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        }
 
        handshake := &peer.handshake
-       if isZero(handshake.precomputedStaticStatic[:]) {
-               return nil
-       }
 
        // verify identity
 
        var timestamp tai64n.Timestamp
-       var key [chacha20poly1305.KeySize]byte
 
        handshake.mutex.RLock()
+
+       if isZero(handshake.precomputedStaticStatic[:]) {
+               handshake.mutex.RUnlock()
+               return nil
+       }
        KDF2(
                &chainKey,
                &key,
                chainKey[:],
                handshake.precomputedStaticStatic[:],
        )
-       aead, _ := chacha20poly1305.New(key[:])
+       aead, _ = chacha20poly1305.New(key[:])
        _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
        if err != nil {
                handshake.mutex.RUnlock()
index 91d975aa34832fafabef66c6cc91df3e97b7bccf..8a8224c66204ff02683289ff1f7c7901b9481c4f 100644 (file)
@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        handshake := &peer.handshake
        handshake.mutex.Lock()
        handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
-       ssIsZero := isZero(handshake.precomputedStaticStatic[:])
        handshake.remoteStatic = pk
        handshake.mutex.Unlock()
 
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        peer.endpoint = nil
 
-       // conditionally add
+       // add
 
-       if !ssIsZero {
-               device.peers.keyMap[pk] = peer
-       } else {
-               return nil, nil
-       }
+       device.peers.keyMap[pk] = peer
 
        // start peer