]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
replay: clean up internals and better documentation
authorRiobard Zhan <me@riobard.com>
Wed, 9 Sep 2020 17:55:24 +0000 (01:55 +0800)
committerJason A. Donenfeld <Jason@zx2c4.com>
Wed, 14 Oct 2020 08:46:00 +0000 (10:46 +0200)
Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
replay/replay.go
replay/replay_test.go

index 85647f5aacda5a2ae52a81db7814e7f467412130..8685712a853b683dbedfe8063bdb9c1147891a31 100644 (file)
@@ -3,81 +3,60 @@
  * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
  */
 
+// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
 package replay
 
-/* Implementation of RFC6479
- * https://tools.ietf.org/html/rfc6479
- *
- * The implementation is not safe for concurrent use!
- */
-
-const (
-       // See: https://golang.org/src/math/big/arith.go
-       _Wordm       = ^uintptr(0)
-       _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
-       _WordSize    = 1 << _WordLogSize
-)
+type block uint64
 
 const (
-       CounterRedundantBitsLog = _WordLogSize + 3
-       CounterRedundantBits    = _WordSize * 8
-       CounterBitsTotal        = 8192
-       CounterWindowSize       = uint64(CounterBitsTotal - CounterRedundantBits)
+       blockBitLog = 6                // 1<<6 == 64 bits
+       blockBits   = 1 << blockBitLog // must be power of 2
+       ringBlocks  = 1 << 7           // must be power of 2
+       windowSize  = (ringBlocks - 1) * blockBits
+       blockMask   = ringBlocks - 1
+       bitMask     = blockBits - 1
 )
 
-const (
-       BacktrackWords = CounterBitsTotal / 8 / _WordSize
-)
-
-func minUint64(a uint64, b uint64) uint64 {
-       if a > b {
-               return b
-       }
-       return a
-}
-
+// A ReplayFilter rejects replayed messages by checking if message counter value is
+// within a sliding window of previously received messages.
+// The zero value for ReplayFilter is an empty filter ready to use.
+// Filters are unsafe for concurrent use.
 type ReplayFilter struct {
-       counter   uint64
-       backtrack [BacktrackWords]uintptr
+       last uint64
+       ring [ringBlocks]block
 }
 
-func (filter *ReplayFilter) Init() {
-       filter.counter = 0
-       filter.backtrack[0] = 0
+// Init resets the filter to empty state.
+func (f *ReplayFilter) Init() {
+       f.last = 0
+       f.ring[0] = 0
 }
 
-func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
+// ValidateCounter checks if the counter should be accepted.
+// Overlimit counters (>= limit) are always rejected.
+func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
        if counter >= limit {
                return false
        }
-
-       indexWord := counter >> CounterRedundantBitsLog
-
-       if counter > filter.counter {
-
-               // move window forward
-
-               current := filter.counter >> CounterRedundantBitsLog
-               diff := minUint64(indexWord-current, BacktrackWords)
-               for i := uint64(1); i <= diff; i++ {
-                       filter.backtrack[(current+i)%BacktrackWords] = 0
+       indexBlock := counter >> blockBitLog
+       if counter > f.last { // move window forward
+               current := f.last >> blockBitLog
+               diff := indexBlock - current
+               if diff > ringBlocks {
+                       diff = ringBlocks // cap diff to clear the whole ring
                }
-               filter.counter = counter
-
-       } else if filter.counter-counter > CounterWindowSize {
-
-               // behind current window
-
+               for i := current + 1; i <= current+diff; i++ {
+                       f.ring[i&blockMask] = 0
+               }
+               f.last = counter
+       } else if f.last-counter > windowSize { // behind current window
                return false
        }
-
-       indexWord %= BacktrackWords
-       indexBit := counter & uint64(CounterRedundantBits-1)
-
        // check and set bit
-
-       oldValue := filter.backtrack[indexWord]
-       newValue := oldValue | (1 << indexBit)
-       filter.backtrack[indexWord] = newValue
-       return oldValue != newValue
+       indexBlock &= blockMask
+       indexBit := counter & bitMask
+       old := f.ring[indexBlock]
+       new := old | 1<<indexBit
+       f.ring[indexBlock] = new
+       return old != new
 }
index ceae2f3c7159095d25608351ec8654f05c622aea..5af66ffaad9d517df5f7164dc5d2e7341009a992 100644 (file)
@@ -19,13 +19,13 @@ const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
 func TestReplay(t *testing.T) {
        var filter ReplayFilter
 
-       T_LIM := CounterWindowSize + 1
+       const T_LIM = windowSize + 1
 
        testNumber := 0
-       T := func(n uint64, v bool) {
+       T := func(n uint64, expected bool) {
                testNumber++
-               if filter.ValidateCounter(n, RejectAfterMessages) != v {
-                       t.Fatal("Test", testNumber, "failed", n, v)
+               if filter.ValidateCounter(n, RejectAfterMessages) != expected {
+                       t.Fatal("Test", testNumber, "failed", n, expected)
                }
        }
 
@@ -69,7 +69,7 @@ func TestReplay(t *testing.T) {
        t.Log("Bulk test 1")
        filter.Init()
        testNumber = 0
-       for i := uint64(1); i <= CounterWindowSize; i++ {
+       for i := uint64(1); i <= windowSize; i++ {
                T(i, true)
        }
        T(0, true)
@@ -78,7 +78,7 @@ func TestReplay(t *testing.T) {
        t.Log("Bulk test 2")
        filter.Init()
        testNumber = 0
-       for i := uint64(2); i <= CounterWindowSize+1; i++ {
+       for i := uint64(2); i <= windowSize+1; i++ {
                T(i, true)
        }
        T(1, true)
@@ -87,14 +87,14 @@ func TestReplay(t *testing.T) {
        t.Log("Bulk test 3")
        filter.Init()
        testNumber = 0
-       for i := CounterWindowSize + 1; i > 0; i-- {
+       for i := uint64(windowSize + 1); i > 0; i-- {
                T(i, true)
        }
 
        t.Log("Bulk test 4")
        filter.Init()
        testNumber = 0
-       for i := CounterWindowSize + 2; i > 1; i-- {
+       for i := uint64(windowSize + 2); i > 1; i-- {
                T(i, true)
        }
        T(0, false)
@@ -102,18 +102,18 @@ func TestReplay(t *testing.T) {
        t.Log("Bulk test 5")
        filter.Init()
        testNumber = 0
-       for i := CounterWindowSize; i > 0; i-- {
+       for i := uint64(windowSize); i > 0; i-- {
                T(i, true)
        }
-       T(CounterWindowSize+1, true)
+       T(windowSize+1, true)
        T(0, false)
 
        t.Log("Bulk test 6")
        filter.Init()
        testNumber = 0
-       for i := CounterWindowSize; i > 0; i-- {
+       for i := uint64(windowSize); i > 0; i-- {
                T(i, true)
        }
        T(0, true)
-       T(CounterWindowSize+1, true)
+       T(windowSize+1, true)
 }