]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
More odds and ends
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 17:50:58 +0000 (19:50 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 17:52:13 +0000 (19:52 +0200)
noise-protocol.go
noise_test.go
receive.go
send.go

index 82d553e107b74584b7e0e46ac592a6e169855fe5..f72dcc480cf2c839aff531ae7d54918cbf0c367b 100644 (file)
@@ -319,6 +319,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
 
        handshake.mutex.Unlock()
 
+       setZero(hash[:])
+       setZero(chainKey[:])
+
        return peer
 }
 
@@ -362,7 +365,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
                handshake.mixKey(ss[:])
        }()
 
-       // add preshared key (psk)
+       // add preshared key
 
        var tau [blake2s.Size]byte
        var key [chacha20poly1305.KeySize]byte
@@ -457,7 +460,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                aead, _ := chacha20poly1305.New(key[:])
                _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
                if err != nil {
-                       device.log.Debug.Println("failed to open")
                        return false
                }
                mixHash(&hash, &hash, msg.Empty[:])
@@ -485,10 +487,10 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        return lookup.peer
 }
 
-/* Derives a new key-pair from the current handshake state
+/* Derives a new keypair from the current handshake state
  *
  */
-func (peer *Peer) NewKeypair() *Keypair {
+func (peer *Peer) DeriveNewKeypair() error {
        device := peer.device
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -517,12 +519,13 @@ func (peer *Peer) NewKeypair() *Keypair {
                )
                isInitiator = false
        } else {
-               return nil
+               return errors.New("invalid state for keypair derivation")
        }
 
        // zero handshake
 
        setZero(handshake.chainKey[:])
+       setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
        setZero(handshake.localEphemeral[:])
        peer.handshake.state = HandshakeZeroed
 
@@ -576,5 +579,23 @@ func (peer *Peer) NewKeypair() *Keypair {
        }
        kp.mutex.Unlock()
 
-       return keypair
+       return nil
+}
+
+func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
+       kp := &peer.keypairs
+       if kp.next != receivedKeypair {
+               return false
+       }
+       kp.mutex.Lock()
+       defer kp.mutex.Unlock()
+       if kp.next != receivedKeypair {
+               return false
+       }
+       old := kp.previous
+       kp.previous = kp.current
+       peer.device.DeleteKeypair(old)
+       kp.current = kp.next
+       kp.next = nil
+       return true
 }
index 37bfb94710062d2fff63083e666f0d186358ff08..ce32097fd6896be71ab21a4b98ae41b648dd6737 100644 (file)
@@ -102,15 +102,15 @@ func TestNoiseHandshake(t *testing.T) {
 
        t.Log("deriving keys")
 
-       key1 := peer1.NewKeypair()
-       key2 := peer2.NewKeypair()
+       key1 := peer1.DeriveNewKeypair()
+       key2 := peer2.DeriveNewKeypair()
 
        if key1 == nil {
-               t.Fatal("failed to dervice key-pair for peer 1")
+               t.Fatal("failed to dervice keypair for peer 1")
        }
 
        if key2 == nil {
-               t.Fatal("failed to dervice key-pair for peer 2")
+               t.Fatal("failed to dervice keypair for peer 2")
        }
 
        // encrypting / decryption test
index 32ff512508c7c03aa5fc3cc96323ce8a0be8394f..64253e6eeadeb7da0e63013f401556665752a094 100644 (file)
@@ -189,7 +189,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                                continue
                        }
 
-                       // check key-pair expiry
+                       // check keypair expiry
 
                        if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
                                continue
@@ -475,7 +475,7 @@ func (device *Device) RoutineHandshake() {
                                continue
                        }
 
-                       if peer.NewKeypair() == nil {
+                       if peer.DeriveNewKeypair() != nil {
                                continue
                        }
 
@@ -532,9 +532,9 @@ func (device *Device) RoutineHandshake() {
                        peer.timersAnyAuthenticatedPacketTraversal()
                        peer.timersAnyAuthenticatedPacketReceived()
 
-                       // derive key-pair
+                       // derive keypair
 
-                       if peer.NewKeypair() == nil {
+                       if peer.DeriveNewKeypair() != nil {
                                continue
                        }
 
@@ -597,25 +597,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
-                       // check if using new key-pair
-
-                       kp := &peer.keypairs
-                       if kp.next == elem.keypair {
-                               kp.mutex.Lock()
-                               if kp.next != elem.keypair {
-                                       kp.mutex.Unlock()
-                               } else {
-                                       old := kp.previous
-                                       kp.previous = kp.current
-                                       device.DeleteKeypair(old)
-                                       kp.current = kp.next
-                                       kp.next = nil
-                                       kp.mutex.Unlock()
-                                       peer.timersHandshakeComplete()
-                                       select {
-                                       case peer.signals.newKeypairArrived <- struct{}{}:
-                                       default:
-                                       }
+                       // check if using new keypair
+                       if peer.ReceivedWithKeypair(elem.keypair) {
+                               peer.timersHandshakeComplete()
+                               select {
+                               case peer.signals.newKeypairArrived <- struct{}{}:
+                               default:
                                }
                        }
 
diff --git a/send.go b/send.go
index 35e0d0008f581ed802c8f999841fdd44eee76bef..a8ec28cf4d36d8c3d7a7c5ad1c7c9c0f6a046d94 100644 (file)
--- a/send.go
+++ b/send.go
@@ -47,7 +47,7 @@ type QueueOutboundElement struct {
        buffer  *[MaxMessageSize]byte // slice holding the packet data
        packet  []byte                // slice of "buffer" (always!)
        nonce   uint64                // nonce for encryption
-       keypair *Keypair              // key-pair for encryption
+       keypair *Keypair              // keypair for encryption
        peer    *Peer                 // related peer
 }
 
@@ -306,11 +306,11 @@ func (peer *Peer) RoutineNonce() {
 
                                peer.SendHandshakeInitiation(false)
 
-                               logDebug.Println(peer, ": Awaiting key-pair")
+                               logDebug.Println(peer, ": Awaiting keypair")
 
                                select {
                                case <-peer.signals.newKeypairArrived:
-                                       logDebug.Println(peer, ": Obtained awaited key-pair")
+                                       logDebug.Println(peer, ": Obtained awaited keypair")
                                case <-peer.signals.flushNonceQueue:
                                        for {
                                                select {