]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: fix WaitPool sync.Cond usage
authorJordan Whited <jordan@tailscale.com>
Thu, 27 Jun 2024 15:43:41 +0000 (08:43 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 4 May 2025 16:11:00 +0000 (18:11 +0200)
The sync.Locker used with a sync.Cond must be acquired when changing
the associated condition, otherwise there is a window within
sync.Cond.Wait() where a wake-up may be missed.

Fixes: 4846070 ("device: use a waiting sync.Pool instead of a channel")
Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/pools.go
device/pools_test.go

index 94f3dc7e6b9271007812298877978ac213d7f569..55d2be7df69a1e27fbf47c120ed76f5173001871 100644 (file)
@@ -7,14 +7,13 @@ package device
 
 import (
        "sync"
-       "sync/atomic"
 )
 
 type WaitPool struct {
        pool  sync.Pool
        cond  sync.Cond
        lock  sync.Mutex
-       count atomic.Uint32
+       count uint32 // Get calls not yet Put back
        max   uint32
 }
 
@@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
 func (p *WaitPool) Get() any {
        if p.max != 0 {
                p.lock.Lock()
-               for p.count.Load() >= p.max {
+               for p.count >= p.max {
                        p.cond.Wait()
                }
-               p.count.Add(1)
+               p.count++
                p.lock.Unlock()
        }
        return p.pool.Get()
@@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) {
        if p.max == 0 {
                return
        }
-       p.count.Add(^uint32(0))
+       p.lock.Lock()
+       defer p.lock.Unlock()
+       p.count--
        p.cond.Signal()
 }
 
index 82d7493e148a4700749e3778a780788c1f4c34f7..538230b94c8008a96f7c27a9dc3c27650df9294b 100644 (file)
@@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) {
        wg.Add(workers)
        var max atomic.Uint32
        updateMax := func() {
-               count := p.count.Load()
+               p.lock.Lock()
+               count := p.count
+               p.lock.Unlock()
                if count > p.max {
                        t.Errorf("count (%d) > max (%d)", count, p.max)
                }