]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Change queueing drop order and fix memory leaks
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 16 Sep 2018 19:50:58 +0000 (21:50 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 16 Sep 2018 19:50:58 +0000 (21:50 +0200)
If the queues are full, we drop the present packet, which is better for
network traffic flow. Also, we try to fix up the memory leaks with not
putting buffers from our shared pool.

receive.go
send.go

index b23c5e05c94e12afe311dfd9b450a670f8d8a295..6b6543c2d094a8b2df2a2b803cba8eaa4d1fcdf5 100644 (file)
@@ -43,59 +43,28 @@ func (elem *QueueInboundElement) IsDropped() bool {
        return atomic.LoadInt32(&elem.dropped) == AtomicTrue
 }
 
-func (device *Device) addToInboundQueue(
-       queue chan *QueueInboundElement,
-       element *QueueInboundElement,
-) {
-       for {
+func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
+       select {
+       case inboundQueue <- element:
                select {
-               case queue <- element:
-                       return
+               case decryptionQueue <- element:
+                       return true
                default:
-                       select {
-                       case old := <-queue:
-                               old.Drop()
-                       default:
-                       }
+                       element.Drop()
+                       element.mutex.Unlock()
+                       return false
                }
+       default:
+               return false
        }
 }
 
-func (device *Device) addToDecryptionQueue(
-       queue chan *QueueInboundElement,
-       element *QueueInboundElement,
-) {
-       for {
-               select {
-               case queue <- element:
-                       return
-               default:
-                       select {
-                       case old := <-queue:
-                               // drop & release to potential consumer
-                               old.Drop()
-                               old.mutex.Unlock()
-                       default:
-                       }
-               }
-       }
-}
-
-func (device *Device) addToHandshakeQueue(
-       queue chan QueueHandshakeElement,
-       element QueueHandshakeElement,
-) {
-       for {
-               select {
-               case queue <- element:
-                       return
-               default:
-                       select {
-                       case elem := <-queue:
-                               device.PutMessageBuffer(elem.buffer)
-                       default:
-                       }
-               }
+func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
+       select {
+       case queue <- element:
+               return true
+       default:
+               return false
        }
 }
 
@@ -154,6 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                }
 
                if err != nil {
+                       device.PutMessageBuffer(buffer)
                        return
                }
 
@@ -212,9 +182,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                        // add to decryption queues
 
                        if peer.isRunning.Get() {
-                               device.addToDecryptionQueue(device.queue.decryption, elem)
-                               device.addToInboundQueue(peer.queue.inbound, elem)
-                               buffer = device.GetMessageBuffer()
+                               if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
+                                       buffer = device.GetMessageBuffer()
+                               }
                        }
 
                        continue
@@ -235,7 +205,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                }
 
                if okay {
-                       device.addToHandshakeQueue(
+                       if (device.addToHandshakeQueue(
                                device.queue.handshake,
                                QueueHandshakeElement{
                                        msgType:  msgType,
@@ -243,8 +213,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                                        packet:   packet,
                                        endpoint: endpoint,
                                },
-                       )
-                       buffer = device.GetMessageBuffer()
+                       )) {
+                               buffer = device.GetMessageBuffer()
+                       }
                }
        }
 }
@@ -307,6 +278,8 @@ func (device *Device) RoutineDecryption() {
                        )
                        if err != nil {
                                elem.Drop()
+                               device.PutMessageBuffer(elem.buffer)
+                               elem.mutex.Unlock()
                        }
                        elem.mutex.Unlock()
                }
diff --git a/send.go b/send.go
index 37ae7384b2d957b5896b6ec89041c701e0cc6fc0..3b6cfa30c4faa44130643b7bba21a8bb0d9f69e7 100644 (file)
--- a/send.go
+++ b/send.go
@@ -66,10 +66,7 @@ func (elem *QueueOutboundElement) IsDropped() bool {
        return atomic.LoadInt32(&elem.dropped) == AtomicTrue
 }
 
-func addToOutboundQueue(
-       queue chan *QueueOutboundElement,
-       element *QueueOutboundElement,
-) {
+func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
        for {
                select {
                case queue <- element:
@@ -78,32 +75,30 @@ func addToOutboundQueue(
                        select {
                        case old := <-queue:
                                old.Drop()
+                               device.PutMessageBuffer(element.buffer)
                        default:
                        }
                }
        }
 }
 
-func addToEncryptionQueue(
-       queue chan *QueueOutboundElement,
-       element *QueueOutboundElement,
-) {
-       for {
+func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
+       select {
+       case outboundQueue <- element:
                select {
-               case queue <- element:
+               case encryptionQueue <- element:
                        return
                default:
-                       select {
-                       case old := <-queue:
-                               // drop & release to potential consumer
-                               old.Drop()
-                               old.mutex.Unlock()
-                       default:
-                       }
+                       element.Drop()
+                       element.peer.device.PutMessageBuffer(element.buffer)
+                       element.mutex.Unlock()
                }
+       default:
+               element.peer.device.PutMessageBuffer(element.buffer)
        }
 }
 
+
 /* Queues a keepalive if no packets are queued for peer
  */
 func (peer *Peer) SendKeepalive() bool {
@@ -117,6 +112,7 @@ func (peer *Peer) SendKeepalive() bool {
                peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
                return true
        default:
+               peer.device.PutMessageBuffer(elem.buffer)
                return false
        }
 }
@@ -267,6 +263,7 @@ func (device *Device) RoutineReadFromTUN() {
                                logError.Println("Failed to read packet from TUN device:", err)
                                device.Close()
                        }
+                       device.PutMessageBuffer(elem.buffer)
                        return
                }
 
@@ -308,7 +305,7 @@ func (device *Device) RoutineReadFromTUN() {
                        if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
                                peer.SendHandshakeInitiation(false)
                        }
-                       addToOutboundQueue(peer.queue.nonce, elem)
+                       addToNonceQueue(peer.queue.nonce, elem, device)
                        elem = device.NewOutboundElement()
                }
        }
@@ -342,7 +339,8 @@ func (peer *Peer) RoutineNonce() {
        flush := func() {
                for {
                        select {
-                       case <-peer.queue.nonce:
+                       case elem := <-peer.queue.nonce:
+                               device.PutMessageBuffer(elem.buffer)
                        default:
                                return
                        }
@@ -402,10 +400,12 @@ func (peer *Peer) RoutineNonce() {
                                        logDebug.Println(peer, "- Obtained awaited keypair")
 
                                case <-peer.signals.flushNonceQueue:
+                                       device.PutMessageBuffer(elem.buffer)
                                        flush()
                                        goto NextPacket
 
                                case <-peer.routines.stop:
+                                       device.PutMessageBuffer(elem.buffer)
                                        return
                                }
                        }
@@ -420,6 +420,7 @@ func (peer *Peer) RoutineNonce() {
 
                        if elem.nonce >= RejectAfterMessages {
                                atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
+                               device.PutMessageBuffer(elem.buffer)
                                goto NextPacket
                        }
 
@@ -428,9 +429,7 @@ func (peer *Peer) RoutineNonce() {
                        elem.mutex.Lock()
 
                        // add to parallel and sequential queue
-
-                       addToEncryptionQueue(device.queue.encryption, elem)
-                       addToOutboundQueue(peer.queue.outbound, elem)
+                       addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
                }
        }
 }