]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Improved handling of key-material
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 1 Sep 2017 12:21:53 +0000 (14:21 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 1 Sep 2017 12:21:53 +0000 (14:21 +0200)
src/keypair.go
src/noise_helpers.go
src/noise_protocol.go
src/receive.go
src/send.go
src/timers.go
src/tun_linux.go

index ba9c4379787252215d132dc2221ca1cba7bfe40a..644d040b9ba342ffc84384d375b9d034bfc1fdcf 100644 (file)
@@ -2,14 +2,39 @@ package main
 
 import (
        "crypto/cipher"
+       "golang.org/x/crypto/chacha20poly1305"
+       "reflect"
        "sync"
        "time"
 )
 
+type safeAEAD struct {
+       mutex sync.RWMutex
+       aead  cipher.AEAD
+}
+
+func (con *safeAEAD) clear() {
+       // TODO: improve handling of key material
+       con.mutex.Lock()
+       if con.aead != nil {
+               val := reflect.ValueOf(con.aead)
+               elm := val.Elem()
+               typ := elm.Type()
+               elm.Set(reflect.Zero(typ))
+               con.aead = nil
+       }
+       con.mutex.Unlock()
+}
+
+func (con *safeAEAD) setKey(key *[chacha20poly1305.KeySize]byte) {
+       // TODO: improve handling of key material
+       con.aead, _ = chacha20poly1305.New(key[:])
+}
+
 type KeyPair struct {
-       receive      cipher.AEAD
+       send         safeAEAD
+       receive      safeAEAD
        replayFilter ReplayFilter
-       send         cipher.AEAD
        sendNonce    uint64
        isInitiator  bool
        created      time.Time
@@ -31,7 +56,7 @@ func (kp *KeyPairs) Current() *KeyPair {
 }
 
 func (device *Device) DeleteKeyPair(key *KeyPair) {
-       key.send = nil
-       key.receive = nil
+       key.send.clear()
+       key.receive.clear()
        device.indices.Delete(key.localIndex)
 }
index 105f78f4829096c3af5b9e6be4b71e35efef260a..24302c0c7ac04bdc13b620412b2516ff5f838a4f 100644 (file)
@@ -13,37 +13,47 @@ import (
  * https://tools.ietf.org/html/rfc5869
  */
 
-func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) {
+func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
        mac := hmac.New(func() hash.Hash {
                h, _ := blake2s.New256(nil)
                return h
        }, key)
-       mac.Write(input)
+       mac.Write(in0)
        mac.Sum(sum[:0])
 }
 
-func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) {
-       HMAC(&t0, key, input)
-       HMAC(&t0, t0[:], []byte{0x1})
+func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
+       mac := hmac.New(func() hash.Hash {
+               h, _ := blake2s.New256(nil)
+               return h
+       }, key)
+       mac.Write(in0)
+       mac.Write(in1)
+       mac.Sum(sum[:0])
+}
+
+func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
+       HMAC1(t0, key, input)
+       HMAC1(t0, t0[:], []byte{0x1})
        return
 }
 
-func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) {
+func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
        var prk [blake2s.Size]byte
-       HMAC(&prk, key, input)
-       HMAC(&t0, prk[:], []byte{0x1})
-       HMAC(&t1, prk[:], append(t0[:], 0x2))
-       prk = [blake2s.Size]byte{}
+       HMAC1(&prk, key, input)
+       HMAC1(t0, prk[:], []byte{0x1})
+       HMAC2(t1, prk[:], t0[:], []byte{0x2})
+       setZero(prk[:])
        return
 }
 
-func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) {
+func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
        var prk [blake2s.Size]byte
-       HMAC(&prk, key, input)
-       HMAC(&t0, prk[:], []byte{0x1})
-       HMAC(&t1, prk[:], append(t0[:], 0x2))
-       HMAC(&t2, prk[:], append(t1[:], 0x3))
-       prk = [blake2s.Size]byte{}
+       HMAC1(&prk, key, input)
+       HMAC1(t0, prk[:], []byte{0x1})
+       HMAC2(t1, prk[:], t0[:], []byte{0x2})
+       HMAC2(t2, prk[:], t1[:], []byte{0x3})
+       setZero(prk[:])
        return
 }
 
@@ -55,6 +65,12 @@ func isZero(val []byte) bool {
        return acc == 0
 }
 
+func setZero(arr []byte) {
+       for i := range arr {
+               arr[i] = 0
+       }
+}
+
 /* curve25519 wrappers */
 
 func newPrivateKey() (sk NoisePrivateKey, err error) {
index 1f1301ef7139fdbf2d4db53862e109cdd673fd9c..a50e3dcb6b16ef727bc7076471d5329532e70985 100644 (file)
@@ -109,27 +109,31 @@ var (
        ZeroNonce       [chacha20poly1305.NonceSize]byte
 )
 
-func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
-       return KDF1(c[:], data)
+func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+       KDF1(dst, c[:], data)
 }
 
-func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
-       return blake2s.Sum256(append(h[:], data...))
+func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+       hsh, _ := blake2s.New256(nil)
+       hsh.Write(h[:])
+       hsh.Write(data)
+       hsh.Sum(dst[:0])
+       hsh.Reset()
 }
 
 func (h *Handshake) mixHash(data []byte) {
-       h.hash = mixHash(h.hash, data)
+       mixHash(&h.hash, &h.hash, data)
 }
 
 func (h *Handshake) mixKey(data []byte) {
-       h.chainKey = mixKey(h.chainKey, data)
+       mixKey(&h.chainKey, &h.chainKey, data)
 }
 
 /* Do basic precomputations
  */
 func init() {
        InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
-       InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier))
+       mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
 }
 
 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
@@ -176,7 +180,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        func() {
                var key [chacha20poly1305.KeySize]byte
                ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
-               handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
+               KDF2(
+                       &handshake.chainKey,
+                       &key,
+                       handshake.chainKey[:],
+                       ss[:],
+               )
                aead, _ := chacha20poly1305.New(key[:])
                aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
        }()
@@ -187,7 +196,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        timestamp := Timestamp()
        func() {
                var key [chacha20poly1305.KeySize]byte
-               handshake.chainKey, key = KDF2(
+               KDF2(
+                       &handshake.chainKey,
+                       &key,
                        handshake.chainKey[:],
                        handshake.precomputedStaticStatic[:],
                )
@@ -197,7 +208,6 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 
        handshake.mixHash(msg.Timestamp[:])
        handshake.state = HandshakeInitiationCreated
-
        return &msg, nil
 }
 
@@ -206,9 +216,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
                return nil
        }
 
-       hash := mixHash(InitialHash, device.publicKey[:])
-       hash = mixHash(hash, msg.Ephemeral[:])
-       chainKey := mixKey(InitialChainKey, msg.Ephemeral[:])
+       var (
+               hash     [blake2s.Size]byte
+               chainKey [blake2s.Size]byte
+       )
+
+       mixHash(&hash, &InitialHash, device.publicKey[:])
+       mixHash(&hash, &hash, msg.Ephemeral[:])
+       mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
 
        // decrypt static key
 
@@ -217,14 +232,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        func() {
                var key [chacha20poly1305.KeySize]byte
                ss := device.privateKey.sharedSecret(msg.Ephemeral)
-               chainKey, key = KDF2(chainKey[:], ss[:])
+               KDF2(&chainKey, &key, chainKey[:], ss[:])
                aead, _ := chacha20poly1305.New(key[:])
                _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
        }()
        if err != nil {
                return nil
        }
-       hash = mixHash(hash, msg.Static[:])
+       mixHash(&hash, &hash, msg.Static[:])
 
        // lookup peer
 
@@ -244,7 +259,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        var key [chacha20poly1305.KeySize]byte
 
        handshake.mutex.RLock()
-       chainKey, key = KDF2(
+       KDF2(
+               &chainKey,
+               &key,
                chainKey[:],
                handshake.precomputedStaticStatic[:],
        )
@@ -254,7 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
                handshake.mutex.RUnlock()
                return nil
        }
-       hash = mixHash(hash, msg.Timestamp[:])
+       mixHash(&hash, &hash, msg.Timestamp[:])
 
        // protect against replay & flood
 
@@ -327,7 +344,15 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
        var tau [blake2s.Size]byte
        var key [chacha20poly1305.KeySize]byte
-       handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
+
+       KDF3(
+               &handshake.chainKey,
+               &tau,
+               &key,
+               handshake.chainKey[:],
+               handshake.presharedKey[:],
+       )
+
        handshake.mixHash(tau[:])
 
        func() {
@@ -337,6 +362,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        }()
 
        handshake.state = HandshakeResponseCreated
+
        return &msg, nil
 }
 
@@ -371,22 +397,33 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 
                // finish 3-way DH
 
-               hash = mixHash(handshake.hash, msg.Ephemeral[:])
-               chainKey = mixKey(handshake.chainKey, msg.Ephemeral[:])
+               mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
+               mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
 
                func() {
                        ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
-                       chainKey = mixKey(chainKey, ss[:])
-                       ss = device.privateKey.sharedSecret(msg.Ephemeral)
-                       chainKey = mixKey(chainKey, ss[:])
+                       mixKey(&chainKey, &chainKey, ss[:])
+                       setZero(ss[:])
+               }()
+
+               func() {
+                       ss := device.privateKey.sharedSecret(msg.Ephemeral)
+                       mixKey(&chainKey, &chainKey, ss[:])
+                       setZero(ss[:])
                }()
 
                // add preshared key (psk)
 
                var tau [blake2s.Size]byte
                var key [chacha20poly1305.KeySize]byte
-               chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
-               hash = mixHash(hash, tau[:])
+               KDF3(
+                       &chainKey,
+                       &tau,
+                       &key,
+                       chainKey[:],
+                       handshake.presharedKey[:],
+               )
+               mixHash(&hash, &hash, tau[:])
 
                // authenticate
 
@@ -396,7 +433,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                        device.log.Debug.Println("failed to open")
                        return false
                }
-               hash = mixHash(hash, msg.Empty[:])
+               mixHash(&hash, &hash, msg.Empty[:])
                return true
        }()
 
@@ -415,6 +452,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 
        handshake.mutex.Unlock()
 
+       setZero(hash[:])
+       setZero(chainKey[:])
+
        return lookup.peer
 }
 
@@ -422,6 +462,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
  *
  */
 func (peer *Peer) NewKeyPair() *KeyPair {
+       device := peer.device
        handshake := &peer.handshake
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
@@ -433,10 +474,20 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        var recvKey [chacha20poly1305.KeySize]byte
 
        if handshake.state == HandshakeResponseConsumed {
-               sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+               KDF2(
+                       &sendKey,
+                       &recvKey,
+                       handshake.chainKey[:],
+                       nil,
+               )
                isInitiator = true
        } else if handshake.state == HandshakeResponseCreated {
-               recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+               KDF2(
+                       &recvKey,
+                       &sendKey,
+                       handshake.chainKey[:],
+                       nil,
+               )
                isInitiator = false
        } else {
                return nil
@@ -444,16 +495,20 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 
        // zero handshake
 
-       handshake.chainKey = [blake2s.Size]byte{}
-       handshake.localEphemeral = NoisePrivateKey{}
+       setZero(handshake.chainKey[:])
+       setZero(handshake.localEphemeral[:])
        peer.handshake.state = HandshakeZeroed
 
        // create AEAD instances
 
        keyPair := new(KeyPair)
+       keyPair.send.setKey(&sendKey)
+       keyPair.receive.setKey(&recvKey)
+
+       setZero(sendKey[:])
+       setZero(recvKey[:])
+
        keyPair.created = time.Now()
-       keyPair.send, _ = chacha20poly1305.New(sendKey[:])
-       keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
        keyPair.sendNonce = 0
        keyPair.replayFilter.Init()
        keyPair.isInitiator = isInitiator
@@ -462,12 +517,14 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 
        // remap index
 
-       indices := &peer.device.indices
-       indices.Insert(handshake.localIndex, IndexTableEntry{
-               peer:      peer,
-               keyPair:   keyPair,
-               handshake: nil,
-       })
+       device.indices.Insert(
+               handshake.localIndex,
+               IndexTableEntry{
+                       peer:      peer,
+                       keyPair:   keyPair,
+                       handshake: nil,
+               },
+       )
        handshake.localIndex = 0
 
        // rotate key pairs
@@ -479,7 +536,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
                // TODO: Adapt kernel behavior noise.c:161
                if isInitiator {
                        if kp.previous != nil {
-                               indices.Delete(kp.previous.localIndex)
+                               device.DeleteKeyPair(kp.previous)
+                               kp.previous = nil
                        }
 
                        if kp.next != nil {
index ca7bb6ec53fd2e6a763c3016f4d2d62576b34a27..97646d88712594148f34c5234ffcef92c1a0b3ba 100644 (file)
@@ -251,15 +251,22 @@ func (device *Device) RoutineDecryption() {
                var err error
                copy(nonce[4:], counter)
                elem.counter = binary.LittleEndian.Uint64(counter)
-               elem.packet, err = elem.keyPair.receive.Open(
-                       elem.buffer[:0],
-                       nonce[:],
-                       content,
-                       nil,
-               )
-               if err != nil {
+               elem.keyPair.receive.mutex.RLock()
+               if elem.keyPair.receive.aead == nil {
+                       // very unlikely (the key was deleted during queuing)
                        elem.Drop()
+               } else {
+                       elem.packet, err = elem.keyPair.receive.aead.Open(
+                               elem.buffer[:0],
+                               nonce[:],
+                               content,
+                               nil,
+                       )
+                       if err != nil {
+                               elem.Drop()
+                       }
                }
+               elem.keyPair.receive.mutex.RUnlock()
                elem.mutex.Unlock()
        }
 }
@@ -507,6 +514,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
                kp.mutex.Lock()
                if kp.next == elem.keyPair {
                        peer.TimerHandshakeComplete()
+                       if kp.previous != nil {
+                               device.DeleteKeyPair(kp.previous)
+                       }
                        kp.previous = kp.current
                        kp.current = kp.next
                        kp.next = nil
index 7d4014a9b0541c03742eb9fbecb80fbe6cbe6d13..c598ad417a8d119c87961a172112391fea75e034 100644 (file)
@@ -349,12 +349,19 @@ func (device *Device) RoutineEncryption() {
                // encrypt content (append to header)
 
                binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
-               elem.packet = elem.keyPair.send.Seal(
-                       header,
-                       nonce[:],
-                       elem.packet,
-                       nil,
-               )
+               elem.keyPair.send.mutex.RLock()
+               if elem.keyPair.send.aead == nil {
+                       // very unlikely (the key was deleted during queuing)
+                       elem.Drop()
+               } else {
+                       elem.packet = elem.keyPair.send.aead.Seal(
+                               header,
+                               nonce[:],
+                               elem.packet,
+                               nil,
+                       )
+               }
+               elem.keyPair.send.mutex.RUnlock()
                elem.mutex.Unlock()
 
                // refresh key if necessary
index de54a96db934e7e25a545592db0b9096a8d4add5..ad8866f4ec4759a478fc85d8a876ff96bee63c18 100644 (file)
@@ -3,7 +3,6 @@ package main
 import (\r
        "bytes"\r
        "encoding/binary"\r
-       "golang.org/x/crypto/blake2s"\r
        "math/rand"\r
        "sync/atomic"\r
        "time"\r
@@ -134,7 +133,6 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
 \r
 func (peer *Peer) RoutineTimerHandler() {\r
        device := peer.device\r
-       indices := &device.indices\r
 \r
        logDebug := device.log.Debug\r
        logDebug.Println("Routine, timer handler, started for peer", peer.String())\r
@@ -186,35 +184,31 @@ func (peer *Peer) RoutineTimerHandler() {
                        kp := &peer.keyPairs\r
                        kp.mutex.Lock()\r
 \r
-                       // unmap indecies\r
+                       // remove key-pairs\r
 \r
-                       indices.mutex.Lock()\r
                        if kp.previous != nil {\r
-                               delete(indices.table, kp.previous.localIndex)\r
+                               device.DeleteKeyPair(kp.previous)\r
+                               kp.previous = nil\r
                        }\r
                        if kp.current != nil {\r
-                               delete(indices.table, kp.current.localIndex)\r
+                               device.DeleteKeyPair(kp.current)\r
+                               kp.current = nil\r
                        }\r
                        if kp.next != nil {\r
-                               delete(indices.table, kp.next.localIndex)\r
+                               device.DeleteKeyPair(kp.next)\r
+                               kp.next = nil\r
                        }\r
-                       delete(indices.table, hs.localIndex)\r
-                       indices.mutex.Unlock()\r
-\r
-                       // zero out key pairs (TODO: better than wait for GC)\r
-\r
-                       kp.current = nil\r
-                       kp.previous = nil\r
-                       kp.next = nil\r
                        kp.mutex.Unlock()\r
 \r
                        // zero out handshake\r
 \r
+                       device.indices.Delete(hs.localIndex)\r
+\r
                        hs.localIndex = 0\r
-                       hs.localEphemeral = NoisePrivateKey{}\r
-                       hs.remoteEphemeral = NoisePublicKey{}\r
-                       hs.chainKey = [blake2s.Size]byte{}\r
-                       hs.hash = [blake2s.Size]byte{}\r
+                       setZero(hs.localEphemeral[:])\r
+                       setZero(hs.remoteEphemeral[:])\r
+                       setZero(hs.chainKey[:])\r
+                       setZero(hs.hash[:])\r
                        hs.mutex.Unlock()\r
                }\r
        }\r
index b9541c9d02e849332fadb2c1f7e3eb0f773897cb..58a762ad56e636d3938c723c9701ed3054c68be1 100644 (file)
@@ -63,6 +63,8 @@ func (tun *NativeTun) RoutineNetlinkListener() {
                return
        }
 
+       tun.events <- TUNEventUp // TODO: Fix network namespace problem
+
        for msg := make([]byte, 1<<16); ; {
 
                msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0)