git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: use a waiting sync.Pool instead of a channel
author Jason A. Donenfeld <Jason@zx2c4.com>
Tue, 2 Feb 2021 17:37:49 +0000 (18:37 +0100)
committer Jason A. Donenfeld <Jason@zx2c4.com>
Tue, 2 Feb 2021 18:32:13 +0000 (19:32 +0100)
Channels are FIFO, which means we have guaranteed cache misses.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
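
A note on the rationale above, with an illustrative sketch (not part of this commit; the two helper functions below are made up, while MaxMessageSize, WaitPool and NewWaitPool are real identifiers from the device package): a buffered channel is strictly FIFO, so a receive always yields the element that has been idle the longest and is least likely to still be in any CPU cache, whereas sync.Pool prefers to hand back an object that was just put back and is usually still cache-warm.

package device

// channelReuseExample models the old scheme being removed: a buffered channel
// is strictly FIFO, so a receive always yields the element that has been idle
// longest, i.e. the one least likely to still be resident in any CPU cache.
func channelReuseExample() {
	const n = 4 // illustrative capacity; the old code used PreallocatedBuffersPerPool
	reuse := make(chan *[MaxMessageSize]byte, n)
	for i := 0; i < n; i++ {
		reuse <- new([MaxMessageSize]byte)
	}
	buf := <-reuse // oldest, coldest buffer comes out first
	reuse <- buf   // and goes to the back of the queue on release
}

// waitPoolReuseExample models the new scheme being added: sync.Pool prefers
// objects just released on the same P, so the next Get tends to receive a
// cache-warm buffer. A max of 0 means unbounded, matching PopulatePools in
// the diff below when PreallocatedBuffersPerPool == 0.
func waitPoolReuseExample() {
	pool := NewWaitPool(0, func() interface{} { return new([MaxMessageSize]byte) })
	buf := pool.Get().(*[MaxMessageSize]byte)
	pool.Put(buf) // likely handed straight back to the next Get
}

When max is non-zero, the WaitPool added below keeps these sync.Pool semantics while still enforcing the same upper bound on outstanding elements that the preallocated channels used to provide.
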
device/alignment_test.go
device/device.go
device/pools.go
device/pools_test.go [new file with mode: 0644]

index 5587cbe1317671679c141e7d6337b6d7427b3f66..46baeb1e6559b518e7fb0f485cea7742dd7e60d2 100644 (file)
@@ -42,7 +42,6 @@ func TestPeerAlignment(t *testing.T) {
        checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
 }
 
-
 // TestDeviceAlignment checks that atomically-accessed fields are
 // aligned to 64-bit boundaries, as required by the atomic package.
 //
index bac361ec4ad2d71048141d90dde6db2c8ec4056f..5f360362a82cb3e5ec4b3aa6f30abd2e1709f0ab 100644 (file)
@@ -67,12 +67,9 @@ type Device struct {
        }
 
        pool struct {
-               messageBufferPool        *sync.Pool
-               messageBufferReuseChan   chan *[MaxMessageSize]byte
-               inboundElementPool       *sync.Pool
-               inboundElementReuseChan  chan *QueueInboundElement
-               outboundElementPool      *sync.Pool
-               outboundElementReuseChan chan *QueueOutboundElement
+               messageBuffers   *WaitPool
+               inboundElements  *WaitPool
+               outboundElements *WaitPool
        }
 
        queue struct {
index eb6d6beb55ee15a6c621e57dde132cd7a998498b..f1d1fa099c4554cf44a34a560a07839e49c94bc2 100644 (file)
@@ -5,87 +5,80 @@
 
 package device
 
-import "sync"
+import (
+       "sync"
+       "sync/atomic"
+)
 
-func (device *Device) PopulatePools() {
-       if PreallocatedBuffersPerPool == 0 {
-               device.pool.messageBufferPool = &sync.Pool{
-                       New: func() interface{} {
-                               return new([MaxMessageSize]byte)
-                       },
-               }
-               device.pool.inboundElementPool = &sync.Pool{
-                       New: func() interface{} {
-                               return new(QueueInboundElement)
-                       },
-               }
-               device.pool.outboundElementPool = &sync.Pool{
-                       New: func() interface{} {
-                               return new(QueueOutboundElement)
-                       },
-               }
-       } else {
-               device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
-               for i := 0; i < PreallocatedBuffersPerPool; i++ {
-                       device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
-               }
-               device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
-               for i := 0; i < PreallocatedBuffersPerPool; i++ {
-                       device.pool.inboundElementReuseChan <- new(QueueInboundElement)
-               }
-               device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
-               for i := 0; i < PreallocatedBuffersPerPool; i++ {
-                       device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+type WaitPool struct {
+       pool  sync.Pool
+       cond  sync.Cond
+       lock  sync.Mutex
+       count uint32
+       max   uint32
+}
+
+func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
+       p := &WaitPool{pool: sync.Pool{New: new}, max: max}
+       p.cond = sync.Cond{L: &p.lock}
+       return p
+}
+
+func (p *WaitPool) Get() interface{} {
+       if p.max != 0 {
+               p.lock.Lock()
+               for atomic.LoadUint32(&p.count) >= p.max {
+                       p.cond.Wait()
                }
+               atomic.AddUint32(&p.count, 1)
+               p.lock.Unlock()
        }
+       return p.pool.Get()
 }
 
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
-       if PreallocatedBuffersPerPool == 0 {
-               return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
-       } else {
-               return <-device.pool.messageBufferReuseChan
+func (p *WaitPool) Put(x interface{}) {
+       p.pool.Put(x)
+       if p.max == 0 {
+               return
        }
+       atomic.AddUint32(&p.count, ^uint32(0))
+       p.cond.Signal()
+}
+
+func (device *Device) PopulatePools() {
+       device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+               return new([MaxMessageSize]byte)
+       })
+       device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+               return new(QueueInboundElement)
+       })
+       device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+               return new(QueueOutboundElement)
+       })
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+       return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
 }
 
 func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
-       if PreallocatedBuffersPerPool == 0 {
-               device.pool.messageBufferPool.Put(msg)
-       } else {
-               device.pool.messageBufferReuseChan <- msg
-       }
+       device.pool.messageBuffers.Put(msg)
 }
 
 func (device *Device) GetInboundElement() *QueueInboundElement {
-       if PreallocatedBuffersPerPool == 0 {
-               return device.pool.inboundElementPool.Get().(*QueueInboundElement)
-       } else {
-               return <-device.pool.inboundElementReuseChan
-       }
+       return device.pool.inboundElements.Get().(*QueueInboundElement)
 }
 
 func (device *Device) PutInboundElement(elem *QueueInboundElement) {
        elem.clearPointers()
-       if PreallocatedBuffersPerPool == 0 {
-               device.pool.inboundElementPool.Put(elem)
-       } else {
-               device.pool.inboundElementReuseChan <- elem
-       }
+       device.pool.inboundElements.Put(elem)
 }
 
 func (device *Device) GetOutboundElement() *QueueOutboundElement {
-       if PreallocatedBuffersPerPool == 0 {
-               return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
-       } else {
-               return <-device.pool.outboundElementReuseChan
-       }
+       return device.pool.outboundElements.Get().(*QueueOutboundElement)
 }
 
 func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
        elem.clearPointers()
-       if PreallocatedBuffersPerPool == 0 {
-               device.pool.outboundElementPool.Put(elem)
-       } else {
-               device.pool.outboundElementReuseChan <- elem
-       }
+       device.pool.outboundElements.Put(elem)
 }
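
Two details of the WaitPool above are worth spelling out. First, atomic.AddUint32(&p.count, ^uint32(0)) in Put adds 0xFFFFFFFF, which wraps around to a decrement by one; this is the idiom the sync/atomic documentation itself recommends for subtracting from an unsigned counter. Second, with a non-zero max, at most max elements can be checked out at a time and Get blocks on the condition variable until a Put makes room. A minimal sketch of that bounded behaviour (illustrative only, assuming it sits in the device package alongside the code above; the function name is made up):

package device

// boundedWaitPoolExample is illustrative, not part of the commit. With
// max == 2, the third Get blocks until another goroutine returns an element.
func boundedWaitPoolExample() {
	p := NewWaitPool(2, func() interface{} { return new(QueueOutboundElement) })

	a := p.Get() // count: 0 -> 1
	b := p.Get() // count: 1 -> 2 == max

	done := make(chan struct{})
	go func() {
		c := p.Get() // blocks in cond.Wait until count drops below max
		p.Put(c)
		close(done)
	}()

	p.Put(a) // count: 2 -> 1; Signal wakes the blocked Get
	<-done
	p.Put(b)
}
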
diff --git a/device/pools_test.go b/device/pools_test.go
new file mode 100644 (file)
index 0000000..e6cbac5
--- /dev/null
+++ b/device/pools_test.go
@@ -0,0 +1,60 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+       "math/rand"
+       "runtime"
+       "sync"
+       "sync/atomic"
+       "testing"
+       "time"
+)
+
+func TestWaitPool(t *testing.T) {
+       var wg sync.WaitGroup
+       trials := int32(100000)
+       workers := runtime.NumCPU() + 2
+       if workers-4 <= 0 {
+               t.Skip("Not enough cores")
+       }
+       p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
+       wg.Add(workers)
+       max := uint32(0)
+       updateMax := func() {
+               count := atomic.LoadUint32(&p.count)
+               if count > p.max {
+                       t.Errorf("count (%d) > max (%d)", count, p.max)
+               }
+               for {
+                       old := atomic.LoadUint32(&max)
+                       if count <= old {
+                               break
+                       }
+                       if atomic.CompareAndSwapUint32(&max, old, count) {
+                               break
+                       }
+               }
+       }
+       for i := 0; i < workers; i++ {
+               go func() {
+                       defer wg.Done()
+                       for atomic.AddInt32(&trials, -1) > 0 {
+                               updateMax()
+                               x := p.Get()
+                               updateMax()
+                               time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+                               updateMax()
+                               p.Put(x)
+                               updateMax()
+                       }
+               }()
+       }
+       wg.Wait()
+       if max != p.max {
+               t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
+       }
+}
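
To run just the new test from a checkout, plain Go tooling is enough, for example: go test -race -run TestWaitPool ./device (the -race flag is a suggestion here, not something the commit mandates). The test fails if the outstanding count ever exceeds max, or if the observed maximum never actually reaches it.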