]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
ratelimiter: use a fake clock in tests and style cleanups
authorDavid Crawshaw <crawshaw@tailscale.com>
Sun, 8 Dec 2019 23:22:31 +0000 (18:22 -0500)
committerDavid Crawshaw <david@zentus.com>
Mon, 30 Mar 2020 07:38:36 +0000 (18:38 +1100)
The existing test would occasionally flake out with:

--- FAIL: TestRatelimiter (0.12s)
    ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true
FAIL
FAIL    golang.zx2c4.com/wireguard/ratelimiter  0.171s

The fake clock also means the tests run much faster, so
testing this package with -count=1000 now takes < 100ms.

While here, several style cleanups. The most significant one
is unembeding the sync.Mutex fields in the rate limiter objects.
Embedded as they were, the lock methods were accessible
outside the ratelimiter package. As they aren't needed externally,
keep them internal to make them easier to reason about.

Passes `go test -race -count=10000 ./ratelimiter`

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
ratelimiter/ratelimiter.go
ratelimiter/ratelimiter_test.go

index 772c45aed6636765665b8ac99418e09aca4af6b2..a6d0ea26c28098d10bc9330ff5f29648e9290370 100644 (file)
@@ -20,21 +20,23 @@ const (
 )
 
 type RatelimiterEntry struct {
-       sync.Mutex
+       mu       sync.Mutex
        lastTime time.Time
        tokens   int64
 }
 
 type Ratelimiter struct {
-       sync.RWMutex
-       stopReset chan struct{}
+       mu      sync.RWMutex
+       timeNow func() time.Time
+
+       stopReset chan struct{} // send to reset, close to stop
        tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
        tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
 }
 
 func (rate *Ratelimiter) Close() {
-       rate.Lock()
-       defer rate.Unlock()
+       rate.mu.Lock()
+       defer rate.mu.Unlock()
 
        if rate.stopReset != nil {
                close(rate.stopReset)
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
 }
 
 func (rate *Ratelimiter) Init() {
-       rate.Lock()
-       defer rate.Unlock()
+       rate.mu.Lock()
+       defer rate.mu.Unlock()
 
-       // stop any ongoing garbage collection routine
+       if rate.timeNow == nil {
+               rate.timeNow = time.Now
+       }
 
+       // stop any ongoing garbage collection routine
        if rate.stopReset != nil {
                close(rate.stopReset)
        }
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
        rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
        rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
 
-       // start garbage collection routine
+       stopReset := rate.stopReset // store in case Init is called again.
 
+       // Start garbage collection routine.
        go func() {
                ticker := time.NewTicker(time.Second)
                ticker.Stop()
                for {
                        select {
-                       case _, ok := <-rate.stopReset:
+                       case _, ok := <-stopReset:
                                ticker.Stop()
-                               if ok {
-                                       ticker = time.NewTicker(time.Second)
-                               } else {
+                               if !ok {
                                        return
                                }
+                               ticker = time.NewTicker(time.Second)
                        case <-ticker.C:
-                               func() {
-                                       rate.Lock()
-                                       defer rate.Unlock()
-
-                                       for key, entry := range rate.tableIPv4 {
-                                               entry.Lock()
-                                               if time.Since(entry.lastTime) > garbageCollectTime {
-                                                       delete(rate.tableIPv4, key)
-                                               }
-                                               entry.Unlock()
-                                       }
-
-                                       for key, entry := range rate.tableIPv6 {
-                                               entry.Lock()
-                                               if time.Since(entry.lastTime) > garbageCollectTime {
-                                                       delete(rate.tableIPv6, key)
-                                               }
-                                               entry.Unlock()
-                                       }
-
-                                       if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
-                                               ticker.Stop()
-                                       }
-                               }()
+                               if rate.cleanup() {
+                                       ticker.Stop()
+                               }
                        }
                }
        }()
 }
 
+func (rate *Ratelimiter) cleanup() (empty bool) {
+       rate.mu.Lock()
+       defer rate.mu.Unlock()
+
+       for key, entry := range rate.tableIPv4 {
+               entry.mu.Lock()
+               if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+                       delete(rate.tableIPv4, key)
+               }
+               entry.mu.Unlock()
+       }
+
+       for key, entry := range rate.tableIPv6 {
+               entry.mu.Lock()
+               if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+                       delete(rate.tableIPv6, key)
+               }
+               entry.mu.Unlock()
+       }
+
+       return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+}
+
 func (rate *Ratelimiter) Allow(ip net.IP) bool {
        var entry *RatelimiterEntry
        var keyIPv4 [net.IPv4len]byte
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
        IPv4 := ip.To4()
        IPv6 := ip.To16()
 
-       rate.RLock()
+       rate.mu.RLock()
 
        if IPv4 != nil {
                copy(keyIPv4[:], IPv4)
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
                entry = rate.tableIPv6[keyIPv6]
        }
 
-       rate.RUnlock()
+       rate.mu.RUnlock()
 
        // make new entry if not found
 
        if entry == nil {
                entry = new(RatelimiterEntry)
                entry.tokens = maxTokens - packetCost
-               entry.lastTime = time.Now()
-               rate.Lock()
+               entry.lastTime = rate.timeNow()
+               rate.mu.Lock()
                if IPv4 != nil {
                        rate.tableIPv4[keyIPv4] = entry
                        if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
                                rate.stopReset <- struct{}{}
                        }
                }
-               rate.Unlock()
+               rate.mu.Unlock()
                return true
        }
 
        // add tokens to entry
 
-       entry.Lock()
-       now := time.Now()
+       entry.mu.Lock()
+       now := rate.timeNow()
        entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
        entry.lastTime = now
        if entry.tokens > maxTokens {
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
 
        if entry.tokens > packetCost {
                entry.tokens -= packetCost
-               entry.Unlock()
+               entry.mu.Unlock()
                return true
        }
-       entry.Unlock()
+       entry.mu.Unlock()
        return false
 }
index 659bdfb6e4d4c986c65f3a13971daff1947585c3..25d5d63a1b3fd7d3eedfdc5bf30cd3676657ae42 100644 (file)
@@ -11,22 +11,21 @@ import (
        "time"
 )
 
-type RatelimiterResult struct {
+type result struct {
        allowed bool
        text    string
        wait    time.Duration
 }
 
 func TestRatelimiter(t *testing.T) {
+       var rate Ratelimiter
+       var expectedResults []result
 
-       var ratelimiter Ratelimiter
-       var expectedResults []RatelimiterResult
-
-       Nano := func(nano int64) time.Duration {
+       nano := func(nano int64) time.Duration {
                return time.Nanosecond * time.Duration(nano)
        }
 
-       Add := func(res RatelimiterResult) {
+       add := func(res result) {
                expectedResults = append(
                        expectedResults,
                        res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
        }
 
        for i := 0; i < packetsBurstable; i++ {
-               Add(RatelimiterResult{
+               add(result{
                        allowed: true,
                        text:    "initial burst",
                })
        }
 
-       Add(RatelimiterResult{
+       add(result{
                allowed: false,
                text:    "after burst",
        })
 
-       Add(RatelimiterResult{
+       add(result{
                allowed: true,
-               wait:    Nano(time.Second.Nanoseconds() / packetsPerSecond),
+               wait:    nano(time.Second.Nanoseconds() / packetsPerSecond),
                text:    "filling tokens for single packet",
        })
 
-       Add(RatelimiterResult{
+       add(result{
                allowed: false,
                text:    "not having refilled enough",
        })
 
-       Add(RatelimiterResult{
+       add(result{
                allowed: true,
-               wait:    2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
+               wait:    2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
                text:    "filling tokens for two packet burst",
        })
 
-       Add(RatelimiterResult{
+       add(result{
                allowed: true,
                text:    "second packet in 2 packet burst",
        })
 
-       Add(RatelimiterResult{
+       add(result{
                allowed: false,
                text:    "packet following 2 packet burst",
        })
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
                net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
        }
 
-       ratelimiter.Init()
+       now := time.Now()
+       rate.timeNow = func() time.Time {
+               return now
+       }
+       defer func() {
+               // Lock to avoid data race with cleanup goroutine from Init.
+               rate.mu.Lock()
+               defer rate.mu.Unlock()
+
+               rate.timeNow = time.Now
+       }()
+       timeSleep := func(d time.Duration) {
+               now = now.Add(d + 1)
+               rate.cleanup()
+       }
+
+       rate.Init()
+       defer rate.Close()
 
        for i, res := range expectedResults {
-               time.Sleep(res.wait)
+               timeSleep(res.wait)
                for _, ip := range ips {
-                       allowed := ratelimiter.Allow(ip)
+                       allowed := rate.Allow(ip)
                        if allowed != res.allowed {
-                               t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
+                               t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
                        }
                }
        }