]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
More pooling
authorJason A. Donenfeld <Jason@zx2c4.com>
Sat, 22 Sep 2018 04:29:02 +0000 (06:29 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 23 Sep 2018 22:37:43 +0000 (00:37 +0200)
device.go
pools.go [new file with mode: 0644]
receive.go
send.go

index bbcd0fc4afba0d836b6b57f2d84c660c237748ea..7cf9ba2253e12b6e94d4d622ef6ee9dbcc9c1f6e 100644 (file)
--- a/device.go
+++ b/device.go
@@ -19,8 +19,6 @@ const (
        DeviceRoutineNumberAdditional = 2
 )
 
-var preallocatedBuffers = 0
-
 type Device struct {
        isUp     AtomicBool // device is (going) up
        isClosed AtomicBool // device is closed? (acting as guard)
@@ -68,8 +66,12 @@ type Device struct {
        }
 
        pool struct {
-               messageBuffers *sync.Pool
-               reuseChan      chan interface{}
+               messageBufferPool        *sync.Pool
+               messageBufferReuseChan   chan *[MaxMessageSize]byte
+               inboundElementPool       *sync.Pool
+               inboundElementReuseChan  chan *QueueInboundElement
+               outboundElementPool      *sync.Pool
+               outboundElementReuseChan chan *QueueOutboundElement
        }
 
        queue struct {
@@ -245,22 +247,6 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        return nil
 }
 
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
-       if preallocatedBuffers == 0 {
-               return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
-       } else {
-               return (<-device.pool.reuseChan).(*[MaxMessageSize]byte)
-       }
-}
-
-func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
-       if preallocatedBuffers == 0 {
-               device.pool.messageBuffers.Put(msg)
-       } else {
-               device.pool.reuseChan <- msg
-       }
-}
-
 func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
        device := new(Device)
 
@@ -285,18 +271,7 @@ func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
        device.indexTable.Init()
        device.allowedips.Reset()
 
-       if preallocatedBuffers == 0 {
-               device.pool.messageBuffers = &sync.Pool{
-                       New: func() interface{} {
-                               return new([MaxMessageSize]byte)
-                       },
-               }
-       } else {
-               device.pool.reuseChan = make(chan interface{}, preallocatedBuffers)
-               for i := 0; i < preallocatedBuffers; i += 1 {
-                       device.pool.reuseChan <- new([MaxMessageSize]byte)
-               }
-       }
+       device.PopulatePools()
 
        // create queues
 
diff --git a/pools.go b/pools.go
new file mode 100644 (file)
index 0000000..fe219f4
--- /dev/null
+++ b/pools.go
@@ -0,0 +1,91 @@
+/* SPDX-License-Identifier: GPL-2.0
+ *
+ * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import "sync"
+
+var preallocatedBuffers = 0
+
+func (device *Device) PopulatePools() {
+       if preallocatedBuffers == 0 {
+               device.pool.messageBufferPool = &sync.Pool{
+                       New: func() interface{} {
+                               return new([MaxMessageSize]byte)
+                       },
+               }
+               device.pool.inboundElementPool = &sync.Pool{
+                       New: func() interface{} {
+                               return new(QueueInboundElement)
+                       },
+               }
+               device.pool.outboundElementPool = &sync.Pool{
+                       New: func() interface{} {
+                               return new(QueueOutboundElement)
+                       },
+               }
+       } else {
+               device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, preallocatedBuffers)
+               for i := 0; i < preallocatedBuffers; i += 1 {
+                       device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
+               }
+               device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, preallocatedBuffers)
+               for i := 0; i < preallocatedBuffers; i += 1 {
+                       device.pool.inboundElementReuseChan <- new(QueueInboundElement)
+               }
+               device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, preallocatedBuffers)
+               for i := 0; i < preallocatedBuffers; i += 1 {
+                       device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+               }
+       }
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+       if preallocatedBuffers == 0 {
+               return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
+       } else {
+               return <-device.pool.messageBufferReuseChan
+       }
+}
+
+func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
+       if preallocatedBuffers == 0 {
+               device.pool.messageBufferPool.Put(msg)
+       } else {
+               device.pool.messageBufferReuseChan <- msg
+       }
+}
+
+func (device *Device) GetInboundElement() *QueueInboundElement {
+       if preallocatedBuffers == 0 {
+               return device.pool.inboundElementPool.Get().(*QueueInboundElement)
+       } else {
+               return <-device.pool.inboundElementReuseChan
+       }
+}
+
+func (device *Device) PutInboundElement(msg *QueueInboundElement) {
+       if preallocatedBuffers == 0 {
+               device.pool.inboundElementPool.Put(msg)
+       } else {
+               device.pool.inboundElementReuseChan <- msg
+       }
+}
+
+func (device *Device) GetOutboundElement() *QueueOutboundElement {
+       if preallocatedBuffers == 0 {
+               return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
+       } else {
+               return <-device.pool.outboundElementReuseChan
+       }
+}
+
+func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
+       if preallocatedBuffers == 0 {
+               device.pool.outboundElementPool.Put(msg)
+       } else {
+               device.pool.outboundElementReuseChan <- msg
+       }
+}
index 9bf3af3ca81fc812e9cd2563bc7d6e731e6527ed..ab8691395fb0e816f654ed38cca779214ffe0883 100644 (file)
@@ -55,6 +55,7 @@ func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueIn
                        return false
                }
        default:
+               device.PutInboundElement(element)
                return false
        }
 }
@@ -168,15 +169,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                        }
 
                        // create work element
-
                        peer := value.peer
-                       elem := &QueueInboundElement{
-                               packet:   packet,
-                               buffer:   buffer,
-                               keypair:  keypair,
-                               dropped:  AtomicFalse,
-                               endpoint: endpoint,
-                       }
+                       elem := device.GetInboundElement()
+                       elem.packet = packet
+                       elem.buffer = buffer
+                       elem.keypair = keypair
+                       elem.dropped = AtomicFalse
+                       elem.endpoint = endpoint
+                       elem.counter = 0
+                       elem.mutex = sync.Mutex{}
                        elem.mutex.Lock()
 
                        // add to decryption queues
@@ -246,6 +247,7 @@ func (device *Device) RoutineDecryption() {
                        // check if dropped
 
                        if elem.IsDropped() {
+                               device.PutInboundElement(elem)
                                continue
                        }
 
@@ -280,7 +282,6 @@ func (device *Device) RoutineDecryption() {
                                elem.Drop()
                                device.PutMessageBuffer(elem.buffer)
                                elem.buffer = nil
-                               elem.mutex.Unlock()
                        }
                        elem.mutex.Unlock()
                }
@@ -487,12 +488,16 @@ func (peer *Peer) RoutineSequentialReceiver() {
        logDebug := device.log.Debug
 
        var elem *QueueInboundElement
+       var ok bool
 
        defer func() {
                logDebug.Println(peer, "- Routine: sequential receiver - stopped")
                peer.routines.stopping.Done()
-               if elem != nil && elem.buffer != nil {
-                       device.PutMessageBuffer(elem.buffer)
+               if elem != nil {
+                       if elem.buffer != nil {
+                               device.PutMessageBuffer(elem.buffer)
+                       }
+                       device.PutInboundElement(elem)
                }
        }()
 
@@ -501,8 +506,11 @@ func (peer *Peer) RoutineSequentialReceiver() {
        peer.routines.starting.Done()
 
        for {
-               if elem != nil && elem.buffer != nil {
-                       device.PutMessageBuffer(elem.buffer)
+               if elem != nil {
+                       if elem.buffer != nil {
+                               device.PutMessageBuffer(elem.buffer)
+                       }
+                       device.PutInboundElement(elem)
                }
 
                select {
@@ -510,7 +518,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                case <-peer.routines.stop:
                        return
 
-               case elem, ok := <-peer.queue.inbound:
+               case elem, ok = <-peer.queue.inbound:
 
                        if !ok {
                                return
@@ -621,9 +629,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                        offset := MessageTransportOffsetContent
                        atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
-                       _, err := device.tun.device.Write(
-                               elem.buffer[:offset+len(elem.packet)],
-                               offset)
+                       _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
                        if err != nil {
                                logError.Println("Failed to write packet to TUN device:", err)
                        }
diff --git a/send.go b/send.go
index 24e2f39549818280013debf6e2baa11907ec7412..fa8404347aa341c5e5829c19f8660cde40558563 100644 (file)
--- a/send.go
+++ b/send.go
@@ -52,10 +52,14 @@ type QueueOutboundElement struct {
 }
 
 func (device *Device) NewOutboundElement() *QueueOutboundElement {
-       return &QueueOutboundElement{
-               dropped: AtomicFalse,
-               buffer:  device.GetMessageBuffer(),
-       }
+       elem := device.GetOutboundElement()
+       elem.dropped = AtomicFalse
+       elem.buffer = device.GetMessageBuffer()
+       elem.mutex = sync.Mutex{}
+       elem.nonce = 0
+       elem.keypair = nil
+       elem.peer = nil
+       return elem
 }
 
 func (elem *QueueOutboundElement) Drop() {
@@ -75,6 +79,7 @@ func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundEle
                        select {
                        case old := <-queue:
                                device.PutMessageBuffer(old.buffer)
+                               device.PutOutboundElement(old)
                        default:
                        }
                }
@@ -94,6 +99,7 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement,
                }
        default:
                element.peer.device.PutMessageBuffer(element.buffer)
+               element.peer.device.PutOutboundElement(element)
        }
 }
 
@@ -111,6 +117,7 @@ func (peer *Peer) SendKeepalive() bool {
                return true
        default:
                peer.device.PutMessageBuffer(elem.buffer)
+               peer.device.PutOutboundElement(elem)
                return false
        }
 }
@@ -236,8 +243,6 @@ func (peer *Peer) keepKeyFreshSending() {
  */
 func (device *Device) RoutineReadFromTUN() {
 
-       elem := device.NewOutboundElement()
-
        logDebug := device.log.Debug
        logError := device.log.Error
 
@@ -249,7 +254,14 @@ func (device *Device) RoutineReadFromTUN() {
        logDebug.Println("Routine: TUN reader - started")
        device.state.starting.Done()
 
+       var elem *QueueOutboundElement
+
        for {
+               if elem != nil {
+                       device.PutMessageBuffer(elem.buffer)
+                       device.PutOutboundElement(elem)
+               }
+               elem = device.NewOutboundElement()
 
                // read packet
 
@@ -262,6 +274,7 @@ func (device *Device) RoutineReadFromTUN() {
                                device.Close()
                        }
                        device.PutMessageBuffer(elem.buffer)
+                       device.PutOutboundElement(elem)
                        return
                }
 
@@ -304,7 +317,7 @@ func (device *Device) RoutineReadFromTUN() {
                                peer.SendHandshakeInitiation(false)
                        }
                        addToNonceQueue(peer.queue.nonce, elem, device)
-                       elem = device.NewOutboundElement()
+                       elem = nil
                }
        }
 }
@@ -339,6 +352,7 @@ func (peer *Peer) RoutineNonce() {
                        select {
                        case elem := <-peer.queue.nonce:
                                device.PutMessageBuffer(elem.buffer)
+                               device.PutOutboundElement(elem)
                        default:
                                return
                        }
@@ -399,11 +413,13 @@ func (peer *Peer) RoutineNonce() {
 
                                case <-peer.signals.flushNonceQueue:
                                        device.PutMessageBuffer(elem.buffer)
+                                       device.PutOutboundElement(elem)
                                        flush()
                                        goto NextPacket
 
                                case <-peer.routines.stop:
                                        device.PutMessageBuffer(elem.buffer)
+                                       device.PutOutboundElement(elem)
                                        return
                                }
                        }
@@ -419,6 +435,7 @@ func (peer *Peer) RoutineNonce() {
                        if elem.nonce >= RejectAfterMessages {
                                atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
                                device.PutMessageBuffer(elem.buffer)
+                               device.PutOutboundElement(elem)
                                goto NextPacket
                        }
 
@@ -468,6 +485,7 @@ func (device *Device) RoutineEncryption() {
                        // check if dropped
 
                        if elem.IsDropped() {
+                               device.PutOutboundElement(elem)
                                continue
                        }
 
@@ -544,6 +562,7 @@ func (peer *Peer) RoutineSequentialSender() {
 
                        elem.mutex.Lock()
                        if elem.IsDropped() {
+                               device.PutOutboundElement(elem)
                                continue
                        }
 
@@ -555,6 +574,7 @@ func (peer *Peer) RoutineSequentialSender() {
                        length := uint64(len(elem.packet))
                        err := peer.SendBuffer(elem.packet)
                        device.PutMessageBuffer(elem.buffer)
+                       device.PutOutboundElement(elem)
                        if err != nil {
                                logError.Println(peer, "- Failed to send data packet", err)
                                continue