]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Fix shutdown races
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 23 Sep 2018 23:52:02 +0000 (01:52 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 23 Sep 2018 23:52:02 +0000 (01:52 +0200)
device.go
receive.go
send.go

index 7cf9ba2253e12b6e94d4d622ef6ee9dbcc9c1f6e..8823404de819e70938e089d51ed5cb102399a427 100644 (file)
--- a/device.go
+++ b/device.go
@@ -377,10 +377,11 @@ func (device *Device) Close() {
 
        close(device.signals.stop)
 
+       device.RemoveAllPeers()
+
        device.state.stopping.Wait()
        device.FlushPacketQueues()
 
-       device.RemoveAllPeers()
        device.rate.limiter.Close()
 
        device.state.changing.Set(false)
index ab8691395fb0e816f654ed38cca779214ffe0883..01151cadd2376b2efe2c76e181467869d87dfda5 100644 (file)
@@ -247,7 +247,6 @@ func (device *Device) RoutineDecryption() {
                        // check if dropped
 
                        if elem.IsDropped() {
-                               device.PutInboundElement(elem)
                                continue
                        }
 
@@ -281,7 +280,6 @@ func (device *Device) RoutineDecryption() {
                        if err != nil {
                                elem.Drop()
                                device.PutMessageBuffer(elem.buffer)
-                               elem.buffer = nil
                        }
                        elem.mutex.Unlock()
                }
@@ -313,6 +311,7 @@ func (device *Device) RoutineHandshake() {
        for {
                if elem.buffer != nil {
                        device.PutMessageBuffer(elem.buffer)
+                       elem.buffer = nil
                }
 
                select {
@@ -494,7 +493,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                logDebug.Println(peer, "- Routine: sequential receiver - stopped")
                peer.routines.stopping.Done()
                if elem != nil {
-                       if elem.buffer != nil {
+                       if !elem.IsDropped() {
                                device.PutMessageBuffer(elem.buffer)
                        }
                        device.PutInboundElement(elem)
@@ -507,10 +506,11 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
        for {
                if elem != nil {
-                       if elem.buffer != nil {
+                       if !elem.IsDropped() {
                                device.PutMessageBuffer(elem.buffer)
                        }
                        device.PutInboundElement(elem)
+                       elem = nil
                }
 
                select {
diff --git a/send.go b/send.go
index fa8404347aa341c5e5829c19f8660cde40558563..b636a437dfda15aea4f79bb00aaf0f544aac7be6 100644 (file)
--- a/send.go
+++ b/send.go
@@ -341,12 +341,6 @@ func (peer *Peer) RoutineNonce() {
        device := peer.device
        logDebug := device.log.Debug
 
-       defer func() {
-               logDebug.Println(peer, "- Routine: nonce worker - stopped")
-               peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
-               peer.routines.stopping.Done()
-       }()
-
        flush := func() {
                for {
                        select {
@@ -359,6 +353,13 @@ func (peer *Peer) RoutineNonce() {
                }
        }
 
+       defer func() {
+               flush()
+               logDebug.Println(peer, "- Routine: nonce worker - stopped")
+               peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
+               peer.routines.stopping.Done()
+       }()
+
        peer.routines.starting.Done()
        logDebug.Println(peer, "- Routine: nonce worker - started")
 
@@ -461,6 +462,19 @@ func (device *Device) RoutineEncryption() {
        logDebug := device.log.Debug
 
        defer func() {
+               for {
+                       select {
+                       case elem, ok := <-device.queue.encryption:
+                               if ok && !elem.IsDropped() {
+                                       elem.Drop()
+                                       device.PutMessageBuffer(elem.buffer)
+                                       elem.mutex.Unlock()
+                               }
+                       default:
+                               goto out
+                       }
+               }
+       out:
                logDebug.Println("Routine: encryption worker - stopped")
                device.state.stopping.Done()
        }()
@@ -485,7 +499,6 @@ func (device *Device) RoutineEncryption() {
                        // check if dropped
 
                        if elem.IsDropped() {
-                               device.PutOutboundElement(elem)
                                continue
                        }
 
@@ -540,6 +553,22 @@ func (peer *Peer) RoutineSequentialSender() {
        logError := device.log.Error
 
        defer func() {
+               for {
+                       select {
+                       case elem, ok := <-peer.queue.outbound:
+                               if ok {
+                                       if !elem.IsDropped() {
+                                               device.PutMessageBuffer(elem.buffer)
+                                               elem.Drop()
+                                       }
+                                       device.PutOutboundElement(elem)
+                                       elem.mutex.Unlock()
+                               }
+                       default:
+                               goto out
+                       }
+               }
+       out:
                logDebug.Println(peer, "- Routine: sequential sender - stopped")
                peer.routines.stopping.Done()
        }()