]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: use channel close to shut down and drain decryption channel
authorJosh Bleecher Snyder <josh@tailscale.com>
Tue, 12 Jan 2021 01:34:02 +0000 (17:34 -0800)
committerJason A. Donenfeld <Jason@zx2c4.com>
Wed, 20 Jan 2021 18:56:54 +0000 (19:56 +0100)
This is similar to commit e1fa1cc5560020e67d33aa7e74674853671cf0a0,
but for the decryption channel.

It is an alternative fix to f9f655567930a4cd78d40fa4ba0d58503335ae6a.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
device/device.go
device/receive.go

index d37fe6f05a4016d7a212f7f934b3c743ca946309..9a9b1b3575f9ed864134d822bf3567a124688631 100644 (file)
@@ -76,7 +76,7 @@ type Device struct {
 
        queue struct {
                encryption *encryptionQueue
-               decryption chan *QueueInboundElement
+               decryption *decryptionQueue
                handshake  chan QueueHandshakeElement
        }
 
@@ -115,6 +115,24 @@ func newEncryptionQueue() *encryptionQueue {
        return q
 }
 
+// A decryptionQueue is similar to an encryptionQueue; see those docs.
+type decryptionQueue struct {
+       c  chan *QueueInboundElement
+       wg sync.WaitGroup
+}
+
+func newDecryptionQueue() *decryptionQueue {
+       q := &decryptionQueue{
+               c: make(chan *QueueInboundElement, QueueInboundSize),
+       }
+       q.wg.Add(1)
+       go func() {
+               q.wg.Wait()
+               close(q.c)
+       }()
+       return q
+}
+
 /* Converts the peer into a "zombie", which remains in the peer map,
  * but processes no packets and does not exists in the routing table.
  *
@@ -308,7 +326,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
 
        device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
        device.queue.encryption = newEncryptionQueue()
-       device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
+       device.queue.decryption = newDecryptionQueue()
 
        // prepare signals
 
@@ -369,13 +387,6 @@ func (device *Device) RemoveAllPeers() {
 func (device *Device) FlushPacketQueues() {
        for {
                select {
-               case elem, ok := <-device.queue.decryption:
-                       if ok {
-                               if !elem.IsDropped() {
-                                       elem.Drop()
-                                       device.PutMessageBuffer(elem.buffer)
-                               }
-                       }
                case <-device.queue.handshake:
                default:
                        return
@@ -399,10 +410,11 @@ func (device *Device) Close() {
 
        device.isUp.Set(false)
 
-       // We kept a reference to the encryption queue,
-       // in case we started any new peers that might write to it.
-       // No new peers are coming; we are done with the encryption queue.
+       // We kept a reference to the encryption and decryption queues,
+       // in case we started any new peers that might write to them.
+       // No new peers are coming; we are done with these queues.
        device.queue.encryption.wg.Done()
+       device.queue.decryption.wg.Done()
        close(device.signals.stop)
        device.state.stopping.Wait()
 
@@ -549,6 +561,7 @@ func (device *Device) BindUpdate() error {
                // start receiving routines
 
                device.net.stopping.Add(2)
+               device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
                go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
                go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
 
index fa31a1aa927f623abc496b4ad903d972ff0e413b..20e0c8fae8d8779cecd2215eba8b0257032fc8a4 100644 (file)
@@ -109,6 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
        logDebug := device.log.Debug
        defer func() {
                logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+               device.queue.decryption.wg.Done()
                device.net.stopping.Done()
        }()
 
@@ -206,7 +207,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
 
                        peer.queue.RLock()
                        if peer.isRunning.Get() {
-                               if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
+                               if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption.c, elem) {
                                        buffer = device.GetMessageBuffer()
                                }
                        } else {
@@ -258,59 +259,35 @@ func (device *Device) RoutineDecryption() {
        }()
        logDebug.Println("Routine: decryption worker - started")
 
-       for {
-               select {
-               case <-device.signals.stop:
-                       for {
-                               select {
-                               case elem, ok := <-device.queue.decryption:
-                                       if ok {
-                                               if !elem.IsDropped() {
-                                                       elem.Drop()
-                                                       device.PutMessageBuffer(elem.buffer)
-                                               }
-                                               elem.Unlock()
-                                       }
-                               default:
-                                       return
-                               }
-                       }
-
-               case elem, ok := <-device.queue.decryption:
+       for elem := range device.queue.decryption.c {
+               // check if dropped
 
-                       if !ok {
-                               return
-                       }
-
-                       // check if dropped
-
-                       if elem.IsDropped() {
-                               continue
-                       }
+               if elem.IsDropped() {
+                       continue
+               }
 
-                       // split message into fields
+               // split message into fields
 
-                       counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
-                       content := elem.packet[MessageTransportOffsetContent:]
+               counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+               content := elem.packet[MessageTransportOffsetContent:]
 
-                       // decrypt and release to consumer
+               // decrypt and release to consumer
 
-                       var err error
-                       elem.counter = binary.LittleEndian.Uint64(counter)
-                       // copy counter to nonce
-                       binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
-                       elem.packet, err = elem.keypair.receive.Open(
-                               content[:0],
-                               nonce[:],
-                               content,
-                               nil,
-                       )
-                       if err != nil {
-                               elem.Drop()
-                               device.PutMessageBuffer(elem.buffer)
-                       }
-                       elem.Unlock()
+               var err error
+               elem.counter = binary.LittleEndian.Uint64(counter)
+               // copy counter to nonce
+               binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+               elem.packet, err = elem.keypair.receive.Open(
+                       content[:0],
+                       nonce[:],
+                       content,
+                       nil,
+               )
+               if err != nil {
+                       elem.Drop()
+                       device.PutMessageBuffer(elem.buffer)
                }
+               elem.Unlock()
        }
 }