]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Completed noise handshake
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 24 Jun 2017 20:03:52 +0000 (22:03 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 24 Jun 2017 20:03:52 +0000 (22:03 +0200)
src/index.go
src/keypair.go
src/noise_helpers.go
src/noise_protocol.go
src/noise_test.go

index 83a7e297997da4e36b85b64d6202437152856160..81f71e9160ac4f8fba23347db4494f16352f5781 100644 (file)
@@ -6,13 +6,15 @@ import (
 )
 
 /* Index=0 is reserved for unset indecies
+ *
+ * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
  *
  */
 
 type IndexTable struct {
        mutex      sync.RWMutex
        keypairs   map[uint32]*KeyPair
-       handshakes map[uint32]*Handshake
+       handshakes map[uint32]*Peer
 }
 
 func randUint32() (uint32, error) {
@@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
        table.mutex.Lock()
        defer table.mutex.Unlock()
        table.keypairs = make(map[uint32]*KeyPair)
-       table.handshakes = make(map[uint32]*Handshake)
+       table.handshakes = make(map[uint32]*Peer)
 }
 
-func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
+func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
        for {
@@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
                        continue
                }
 
-               // update the index
+               // clean old index
 
-               delete(table.handshakes, handshake.localIndex)
-               handshake.localIndex = id
-               table.handshakes[id] = handshake
+               delete(table.handshakes, peer.handshake.localIndex)
+               table.handshakes[id] = peer
                return id, nil
        }
 }
@@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
        return table.keypairs[id]
 }
 
-func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
+func (table *IndexTable) LookupHandshake(id uint32) *Peer {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
        return table.handshakes[id]
index 22a8244d93937ace8a39bf693956d8ef8fc739e5..e434c74dacb38676f2c199b035d2cf1d77240118 100644 (file)
@@ -5,8 +5,8 @@ import (
 )
 
 type KeyPair struct {
-       recieveKey   cipher.AEAD
-       recieveNonce NoiseNonce
-       sendKey      cipher.AEAD
-       sendNonce    NoiseNonce
+       recv      cipher.AEAD
+       recvNonce NoiseNonce
+       send      cipher.AEAD
+       sendNonce NoiseNonce
 }
index eadbc07b0e663870fd1f98efc2e6b56d0d797b81..e163acec07099bcf6df9a4c698e441723a62c507 100644 (file)
@@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
        return
 }
 
-/*
- *
- */
-
-func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
-       return KDF1(c[:], data)
-}
-
-func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
-       return blake2s.Sum256(append(h[:], data...))
-}
-
-/* Curve25519 wrappers
- *
- * TODO: Rethink this
- */
+/* curve25519 wrappers */
 
 func newPrivateKey() (sk NoisePrivateKey, err error) {
        // clamping: https://cr.yp.to/ecdh.html
index b9c8981e79c293c054df93ceb8b367aad2c46501..7f26cf1100577a8333335b8b871f8ca34f03b386 100644 (file)
@@ -9,9 +9,11 @@ import (
 )
 
 const (
-       HandshakeInitialCreated = iota
+       HandshakeReset = iota
+       HandshakeInitialCreated
        HandshakeInitialConsumed
        HandshakeResponseCreated
+       HandshakeResponseConsumed
 )
 
 const (
@@ -71,7 +73,6 @@ type Handshake struct {
 }
 
 var (
-       EmptyMessage   []byte
        ZeroNonce      [chacha20poly1305.NonceSize]byte
        InitalChainKey [blake2s.Size]byte
        InitalHash     [blake2s.Size]byte
@@ -82,6 +83,14 @@ func init() {
        InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
 }
 
+func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+       return KDF1(c[:], data)
+}
+
+func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+       return blake2s.Sum256(append(h[:], data...))
+}
+
 func (h *Handshake) addToHash(data []byte) {
        h.hash = addToHash(h.hash, data)
 }
@@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
        h.chainKey = addToChainKey(h.chainKey, data)
 }
 
-func (device *Device) Precompute(peer *Peer) {
-       h := &peer.handshake
-       h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
-}
-
 func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 
        msg.Type = MessageInitalType
        msg.Ephemeral = handshake.localEphemeral.publicKey()
-       msg.Sender, err = device.indices.NewIndex(handshake)
+       handshake.localIndex, err = device.indices.NewIndex(peer)
 
        if err != nil {
                return nil, err
        }
 
+       msg.Sender = handshake.localIndex
        handshake.addToChainKey(msg.Ephemeral[:])
        handshake.addToHash(msg.Ephemeral[:])
 
-       // encrypt long-term "identity key"
+       // encrypt identity key
 
        func() {
                var key [chacha20poly1305.KeySize]byte
@@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
        handshake.chainKey = chainKey
        handshake.remoteIndex = msg.Sender
        handshake.remoteEphemeral = msg.Ephemeral
+       handshake.lastTimestamp = timestamp
        handshake.state = HandshakeInitialConsumed
        return peer
 }
@@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        // assign index
 
        var err error
-       var msg MessageResponse
-       msg.Type = MessageResponseType
-       msg.Sender, err = device.indices.NewIndex(handshake)
-       msg.Reciever = handshake.remoteIndex
+       handshake.localIndex, err = device.indices.NewIndex(peer)
        if err != nil {
                return nil, err
        }
 
+       var msg MessageResponse
+       msg.Type = MessageResponseType
+       msg.Sender = handshake.localIndex
+       msg.Reciever = handshake.remoteIndex
+
        // create ephemeral key
 
        handshake.localEphemeral, err = newPrivateKey()
@@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
                return nil, err
        }
        msg.Ephemeral = handshake.localEphemeral.publicKey()
+       handshake.addToHash(msg.Ephemeral[:])
 
        func() {
                ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
@@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
        func() {
                aead, _ := chacha20poly1305.New(key[:])
-               aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
+               aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
                handshake.addToHash(msg.Empty[:])
        }()
 
+       handshake.state = HandshakeResponseCreated
        return &msg, nil
 }
+
+func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
+       if msg.Type != MessageResponseType {
+               panic(errors.New("bug: invalid message type"))
+       }
+
+       // lookup handshake by reciever
+
+       peer := device.indices.LookupHandshake(msg.Reciever)
+       if peer == nil {
+               return nil
+       }
+       handshake := &peer.handshake
+       handshake.mutex.Lock()
+       defer handshake.mutex.Unlock()
+       if handshake.state != HandshakeInitialCreated {
+               return nil
+       }
+
+       // finish 3-way DH
+
+       hash := addToHash(handshake.hash, msg.Ephemeral[:])
+       chainKey := handshake.chainKey
+
+       func() {
+               ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+               chainKey = addToChainKey(chainKey, ss[:])
+               ss = device.privateKey.sharedSecret(msg.Ephemeral)
+               chainKey = addToChainKey(chainKey, ss[:])
+       }()
+
+       // add preshared key (psk)
+
+       var tau [blake2s.Size]byte
+       var key [chacha20poly1305.KeySize]byte
+       chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
+       hash = addToHash(hash, tau[:])
+
+       // authenticate
+
+       aead, _ := chacha20poly1305.New(key[:])
+       _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+       if err != nil {
+               return nil
+       }
+       hash = addToHash(hash, msg.Empty[:])
+
+       // update handshake state
+
+       handshake.hash = hash
+       handshake.chainKey = chainKey
+       handshake.remoteIndex = msg.Sender
+       handshake.state = HandshakeResponseConsumed
+
+       return peer
+}
+
+func (peer *Peer) NewKeyPair() *KeyPair {
+       handshake := &peer.handshake
+       handshake.mutex.Lock()
+       defer handshake.mutex.Unlock()
+
+       // derive keys
+
+       var sendKey [chacha20poly1305.KeySize]byte
+       var recvKey [chacha20poly1305.KeySize]byte
+
+       if handshake.state == HandshakeResponseConsumed {
+               sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+       } else if handshake.state == HandshakeResponseCreated {
+               recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+       } else {
+               return nil
+       }
+
+       // create AEAD instances
+
+       var keyPair KeyPair
+       keyPair.send, _ = chacha20poly1305.New(sendKey[:])
+       keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
+       keyPair.sendNonce = 0
+       keyPair.recvNonce = 0
+
+       peer.handshake.state = HandshakeReset
+
+       return &keyPair
+}
index 8d6a0fa159252762492ad4eb878d1f8f0754eadb..ddabf8e930d148c2929db549dca12126193c2033 100644 (file)
@@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
 
        /* simulate handshake */
 
-       // Initiation message
+       // initiation message
+
+       t.Log("exchange initiation message")
 
        msg1, err := dev1.CreateMessageInitial(peer2)
        assertNil(t, err)
@@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
                peer2.handshake.hash[:],
        )
 
-       // Response message
+       // response message
+
+       t.Log("exchange response message")
+
+       msg2, err := dev2.CreateMessageResponse(peer1)
+       assertNil(t, err)
+
+       peer = dev1.ConsumeMessageResponse(msg2)
+       if peer == nil {
+               t.Fatal("handshake failed at response message")
+       }
+
+       assertEqual(
+               t,
+               peer1.handshake.chainKey[:],
+               peer2.handshake.chainKey[:],
+       )
+
+       assertEqual(
+               t,
+               peer1.handshake.hash[:],
+               peer2.handshake.hash[:],
+       )
+
+       // key pairs
+
+       t.Log("deriving keys")
+
+       key1 := peer1.NewKeyPair()
+       key2 := peer2.NewKeyPair()
+
+       if key1 == nil {
+               t.Fatal("failed to dervice key-pair for peer 1")
+       }
+
+       if key2 == nil {
+               t.Fatal("failed to dervice key-pair for peer 2")
+       }
 
+       // encrypting / decryption test
+
+       t.Log("test key pairs")
+
+       func() {
+               testMsg := []byte("wireguard test message 1")
+               var err error
+               var out []byte
+               var nonce [12]byte
+               out = key1.send.Seal(out, nonce[:], testMsg, nil)
+               out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
+               assertNil(t, err)
+               assertEqual(t, out, testMsg)
+       }()
+
+       func() {
+               testMsg := []byte("wireguard test message 2")
+               var err error
+               var out []byte
+               var nonce [12]byte
+               out = key2.send.Seal(out, nonce[:], testMsg, nil)
+               out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
+               assertNil(t, err)
+               assertEqual(t, out, testMsg)
+       }()
 }