]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Added replay protection
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 10 Jul 2017 10:09:19 +0000 (12:09 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 10 Jul 2017 10:09:19 +0000 (12:09 +0200)
src/keypair.go
src/misc.go
src/noise_protocol.go
src/receive.go
src/replay.go [new file with mode: 0644]
src/replay_test.go [new file with mode: 0644]
src/timers.go

index b24dbe4fcef2725cb33406e8961c3f61a19d34a1..b5f46df14dee2ea3d3d49895766df340b4ee2ad6 100644 (file)
@@ -7,13 +7,14 @@ import (
 )
 
 type KeyPair struct {
-       receive     cipher.AEAD
-       send        cipher.AEAD
-       sendNonce   uint64
-       isInitiator bool
-       created     time.Time
-       localIndex  uint32
-       remoteIndex uint32
+       receive      cipher.AEAD
+       replayFilter ReplayFilter
+       send         cipher.AEAD
+       sendNonce    uint64
+       isInitiator  bool
+       created      time.Time
+       localIndex   uint32
+       remoteIndex  uint32
 }
 
 type KeyPairs struct {
index 75561b2ff63db82b19710eb585d11408d5644bf0..fc75c0d5d0effe887b2b8312bd8227ee07b95c88 100644 (file)
@@ -19,6 +19,13 @@ func min(a uint, b uint) uint {
        return a
 }
 
+func minUint64(a uint64, b uint64) uint64 {
+       if a > b {
+               return b
+       }
+       return a
+}
+
 func signalSend(c chan struct{}) {
        select {
        case c <- struct{}{}:
index a90fe4cc5a3a376a714f6619a2bcb8741accef33..bfa3797f2d00bc99fd73f38d286146b3002273c1 100644 (file)
@@ -415,6 +415,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        return lookup.peer
 }
 
+/* Derives a new key-pair from the current handshake state
+ *
+ */
 func (peer *Peer) NewKeyPair() *KeyPair {
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -445,10 +448,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        // create AEAD instances
 
        keyPair := new(KeyPair)
+       keyPair.created = time.Now()
        keyPair.send, _ = chacha20poly1305.New(sendKey[:])
        keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
        keyPair.sendNonce = 0
-       keyPair.created = time.Now()
+       keyPair.replayFilter.Init()
        keyPair.isInitiator = isInitiator
        keyPair.localIndex = peer.handshake.localIndex
        keyPair.remoteIndex = peer.handshake.remoteIndex
@@ -462,8 +466,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        })
        handshake.localIndex = 0
 
-       // TODO: start timer for keypair (clearing)
-
        // rotate key pairs
 
        kp := &peer.keyPairs
index e780c66fc7d0afaf4aac5fb9b3e2ec32a3c73e12..6530c478f82288d01e8d1a5b182d0e303f84e8a2 100644 (file)
@@ -432,6 +432,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                        // check for replay
 
+                       if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+                               return
+                       }
+
                        // time (passive) keep-alive
 
                        peer.TimerStartKeepalive()
diff --git a/src/replay.go b/src/replay.go
new file mode 100644 (file)
index 0000000..49c7e08
--- /dev/null
@@ -0,0 +1,71 @@
+package main
+
+/* Implementation of RFC6479
+ * https://tools.ietf.org/html/rfc6479
+ *
+ * The implementation is not safe for concurrent use!
+ */
+
+const (
+       // See: https://golang.org/src/math/big/arith.go
+       _Wordm       = ^uintptr(0)
+       _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
+       _WordSize    = 1 << _WordLogSize
+)
+
+const (
+       CounterRedundantBitsLog = _WordLogSize + 3
+       CounterRedundantBits    = _WordSize * 8
+       CounterBitsTotal        = 2048
+       CounterWindowSize       = uint64(CounterBitsTotal - CounterRedundantBits)
+)
+
+const (
+       BacktrackWords = CounterBitsTotal / _WordSize
+)
+
+type ReplayFilter struct {
+       counter   uint64
+       backtrack [BacktrackWords]uintptr
+}
+
+func (filter *ReplayFilter) Init() {
+       filter.counter = 0
+       filter.backtrack[0] = 0
+}
+
+func (filter *ReplayFilter) ValidateCounter(counter uint64) bool {
+       if counter >= RejectAfterMessages {
+               return false
+       }
+
+       indexWord := counter >> CounterRedundantBitsLog
+
+       if counter > filter.counter {
+
+               // move window forward
+
+               current := filter.counter >> CounterRedundantBitsLog
+               diff := minUint64(indexWord-current, BacktrackWords)
+               for i := uint64(1); i <= diff; i++ {
+                       filter.backtrack[(current+i)%BacktrackWords] = 0
+               }
+               filter.counter = counter
+
+       } else if filter.counter-counter > CounterWindowSize {
+
+               // behind current window
+
+               return false
+       }
+
+       indexWord %= BacktrackWords
+       indexBit := counter & uint64(CounterRedundantBits-1)
+
+       // check and set bit
+
+       oldValue := filter.backtrack[indexWord]
+       newValue := oldValue | (1 << indexBit)
+       filter.backtrack[indexWord] = newValue
+       return oldValue != newValue
+}
diff --git a/src/replay_test.go b/src/replay_test.go
new file mode 100644 (file)
index 0000000..e75c5c1
--- /dev/null
@@ -0,0 +1,114 @@
+package main
+
+import (
+       "testing"
+)
+
+/* Ported from the linux kernel implementation
+ *
+ *
+ */
+
+/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+func TestReplay(t *testing.T) {
+       var filter ReplayFilter
+
+       T_LIM := CounterWindowSize + 1
+
+       testNumber := 0
+       T := func(n uint64, v bool) {
+               testNumber++
+               if filter.ValidateCounter(n) != v {
+                       t.Fatal("Test", testNumber, "failed", n, v)
+               }
+       }
+
+       filter.Init()
+
+       /*  1 */ T(0, true)
+       /*  2 */ T(1, true)
+       /*  3 */ T(1, false)
+       /*  4 */ T(9, true)
+       /*  5 */ T(8, true)
+       /*  6 */ T(7, true)
+       /*  7 */ T(7, false)
+       /*  8 */ T(T_LIM, true)
+       /*  9 */ T(T_LIM-1, true)
+       /* 10 */ T(T_LIM-1, false)
+       /* 11 */ T(T_LIM-2, true)
+       /* 12 */ T(2, true)
+       /* 13 */ T(2, false)
+       /* 14 */ T(T_LIM+16, true)
+       /* 15 */ T(3, false)
+       /* 16 */ T(T_LIM+16, false)
+       /* 17 */ T(T_LIM*4, true)
+       /* 18 */ T(T_LIM*4-(T_LIM-1), true)
+       /* 19 */ T(10, false)
+       /* 20 */ T(T_LIM*4-T_LIM, false)
+       /* 21 */ T(T_LIM*4-(T_LIM+1), false)
+       /* 22 */ T(T_LIM*4-(T_LIM-2), true)
+       /* 23 */ T(T_LIM*4+1-T_LIM, false)
+       /* 24 */ T(0, false)
+       /* 25 */ T(RejectAfterMessages, false)
+       /* 26 */ T(RejectAfterMessages-1, true)
+       /* 27 */ T(RejectAfterMessages, false)
+       /* 28 */ T(RejectAfterMessages-1, false)
+       /* 29 */ T(RejectAfterMessages-2, true)
+       /* 30 */ T(RejectAfterMessages+1, false)
+       /* 31 */ T(RejectAfterMessages+2, false)
+       /* 32 */ T(RejectAfterMessages-2, false)
+       /* 33 */ T(RejectAfterMessages-3, true)
+       /* 34 */ T(0, false)
+
+       t.Log("Bulk test 1")
+       filter.Init()
+       testNumber = 0
+       for i := uint64(1); i <= CounterWindowSize; i++ {
+               T(i, true)
+       }
+       T(0, true)
+       T(0, false)
+
+       t.Log("Bulk test 2")
+       filter.Init()
+       testNumber = 0
+       for i := uint64(2); i <= CounterWindowSize+1; i++ {
+               T(i, true)
+       }
+       T(1, true)
+       T(0, false)
+
+       t.Log("Bulk test 3")
+       filter.Init()
+       testNumber = 0
+       for i := CounterWindowSize + 1; i > 0; i-- {
+               T(i, true)
+       }
+
+       t.Log("Bulk test 4")
+       filter.Init()
+       testNumber = 0
+       for i := CounterWindowSize + 2; i > 1; i-- {
+               T(i, true)
+       }
+       T(0, false)
+
+       t.Log("Bulk test 5")
+       filter.Init()
+       testNumber = 0
+       for i := CounterWindowSize; i > 0; i-- {
+               T(i, true)
+       }
+       T(CounterWindowSize+1, true)
+       T(0, false)
+
+       t.Log("Bulk test 6")
+       filter.Init()
+       testNumber = 0
+       for i := CounterWindowSize; i > 0; i-- {
+               T(i, true)
+       }
+       T(0, true)
+       T(CounterWindowSize+1, true)
+}
index 26926c265bec38f733423659e88657cf6e391de2..70e0766da9fb83427686b0271e8b733f9f7b873b 100644 (file)
@@ -12,22 +12,15 @@ import (
  *
  */
 func (peer *Peer) KeepKeyFreshSending() {
-       send := func() bool {
-               peer.keyPairs.mutex.RLock()
-               defer peer.keyPairs.mutex.RUnlock()
-
-               kp := peer.keyPairs.current
-               if kp == nil {
-                       return false
-               }
-
-               if !kp.isInitiator {
-                       return false
-               }
-
-               nonce := atomic.LoadUint64(&kp.sendNonce)
-               return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
-       }()
+       kp := peer.keyPairs.Current()
+       if kp == nil {
+               return
+       }
+       if !kp.isInitiator {
+               return
+       }
+       nonce := atomic.LoadUint64(&kp.sendNonce)
+       send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
        if send {
                signalSend(peer.signal.handshakeBegin)
        }
@@ -37,22 +30,15 @@ func (peer *Peer) KeepKeyFreshSending() {
  *
  */
 func (peer *Peer) KeepKeyFreshReceiving() {
-       send := func() bool {
-               peer.keyPairs.mutex.RLock()
-               defer peer.keyPairs.mutex.RUnlock()
-
-               kp := peer.keyPairs.current
-               if kp == nil {
-                       return false
-               }
-
-               if !kp.isInitiator {
-                       return false
-               }
-
-               nonce := atomic.LoadUint64(&kp.sendNonce)
-               return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
-       }()
+       kp := peer.keyPairs.Current()
+       if kp == nil {
+               return
+       }
+       if !kp.isInitiator {
+               return
+       }
+       nonce := atomic.LoadUint64(&kp.sendNonce)
+       send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
        if send {
                signalSend(peer.signal.handshakeBegin)
        }