]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: use new model queues for handshakes
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 29 Jan 2021 17:24:45 +0000 (18:24 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 29 Jan 2021 17:24:45 +0000 (18:24 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/device.go
device/receive.go

index 08db24455ec7b22ad686a537bd8016e2496a541a..fd888558506f49fabf3c462dd15176460b720151 100644 (file)
@@ -13,6 +13,7 @@ import (
 
        "golang.org/x/net/ipv4"
        "golang.org/x/net/ipv6"
+
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/ratelimiter"
        "golang.zx2c4.com/wireguard/rwcancel"
@@ -77,11 +78,7 @@ type Device struct {
        queue struct {
                encryption *outboundQueue
                decryption *inboundQueue
-               handshake  chan QueueHandshakeElement
-       }
-
-       signals struct {
-               stop chan struct{}
+               handshake  *handshakeQueue
        }
 
        tun struct {
@@ -90,6 +87,7 @@ type Device struct {
        }
 
        ipcMutex sync.RWMutex
+       closed   chan struct{}
 }
 
 // An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
@@ -135,6 +133,24 @@ func newInboundQueue() *inboundQueue {
        return q
 }
 
+// A handshakeQueue is similar to an outboundQueue; see those docs.
+type handshakeQueue struct {
+       c  chan QueueHandshakeElement
+       wg sync.WaitGroup
+}
+
+func newHandshakeQueue() *handshakeQueue {
+       q := &handshakeQueue{
+               c: make(chan QueueHandshakeElement, QueueHandshakeSize),
+       }
+       q.wg.Add(1)
+       go func() {
+               q.wg.Wait()
+               close(q.c)
+       }()
+       return q
+}
+
 /* Converts the peer into a "zombie", which remains in the peer map,
  * but processes no packets and does not exists in the routing table.
  *
@@ -233,7 +249,7 @@ func (device *Device) IsUnderLoad() bool {
        // check if currently under load
 
        now := time.Now()
-       underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+       underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
        if underLoad {
                device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
                return true
@@ -302,6 +318,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 
 func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        device := new(Device)
+       device.closed = make(chan struct{})
        device.log = logger
        device.tun.device = tunDevice
        mtu, err := device.tun.device.MTU()
@@ -322,14 +339,10 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
 
        // create queues
 
-       device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
+       device.queue.handshake = newHandshakeQueue()
        device.queue.encryption = newOutboundQueue()
        device.queue.decryption = newInboundQueue()
 
-       // prepare signals
-
-       device.signals.stop = make(chan struct{})
-
        // prepare net
 
        device.net.port = 0
@@ -382,18 +395,6 @@ func (device *Device) RemoveAllPeers() {
        device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 }
 
-func (device *Device) FlushPacketQueues() {
-       for {
-               select {
-               case elem := <-device.queue.handshake:
-                       device.PutMessageBuffer(elem.buffer)
-               default:
-                       return
-               }
-       }
-
-}
-
 func (device *Device) Close() {
        if device.isClosed.Swap(true) {
                return
@@ -414,21 +415,20 @@ func (device *Device) Close() {
        // No new peers are coming; we are done with these queues.
        device.queue.encryption.wg.Done()
        device.queue.decryption.wg.Done()
-       close(device.signals.stop)
+       device.queue.handshake.wg.Done()
        device.state.stopping.Wait()
 
        device.RemoveAllPeers()
 
-       device.FlushPacketQueues()
-
        device.rate.limiter.Close()
 
        device.state.changing.Set(false)
        device.log.Verbosef("Interface closed")
+       close(device.closed)
 }
 
 func (device *Device) Wait() chan struct{} {
-       return device.signals.stop
+       return device.closed
 }
 
 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
@@ -561,6 +561,7 @@ func (device *Device) BindUpdate() error {
 
                device.net.stopping.Add(2)
                device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+               device.queue.handshake.wg.Add(2)  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
                go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
                go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
 
index abaf5af2a28907760c2e5b36e2b8850e640e0c9a..0b70137ec7a093810ff8bc57b2c14a92200bfe1b 100644 (file)
@@ -48,15 +48,6 @@ func (elem *QueueInboundElement) clearPointers() {
        elem.endpoint = nil
 }
 
-func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem QueueHandshakeElement) bool {
-       select {
-       case queue <- elem:
-               return true
-       default:
-               return false
-       }
-}
-
 /* Called when a new authenticated message has been received
  *
  * NOTE: Not thread safe, but called by sequential receiver!
@@ -81,6 +72,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
        defer func() {
                device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
                device.queue.decryption.wg.Done()
+               device.queue.handshake.wg.Done()
                device.net.stopping.Done()
        }()
 
@@ -202,16 +194,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
                }
 
                if okay {
-                       if (device.addToHandshakeQueue(
-                               device.queue.handshake,
-                               QueueHandshakeElement{
-                                       msgType:  msgType,
-                                       buffer:   buffer,
-                                       packet:   packet,
-                                       endpoint: endpoint,
-                               },
-                       )) {
+                       select {
+                       case device.queue.handshake.c <- QueueHandshakeElement{
+                               msgType:  msgType,
+                               buffer:   buffer,
+                               packet:   packet,
+                               endpoint: endpoint,
+                       }:
                                buffer = device.GetMessageBuffer()
+                       default:
                        }
                }
        }
@@ -251,34 +242,13 @@ func (device *Device) RoutineDecryption() {
 /* Handles incoming packets related to handshake
  */
 func (device *Device) RoutineHandshake() {
-       var elem QueueHandshakeElement
-       var ok bool
-
        defer func() {
                device.log.Verbosef("Routine: handshake worker - stopped")
                device.state.stopping.Done()
-               if elem.buffer != nil {
-                       device.PutMessageBuffer(elem.buffer)
-               }
        }()
-
        device.log.Verbosef("Routine: handshake worker - started")
 
-       for {
-               if elem.buffer != nil {
-                       device.PutMessageBuffer(elem.buffer)
-                       elem.buffer = nil
-               }
-
-               select {
-               case elem, ok = <-device.queue.handshake:
-               case <-device.signals.stop:
-                       return
-               }
-
-               if !ok {
-                       return
-               }
+       for elem := range device.queue.handshake.c {
 
                // handle cookie fields and ratelimiting
 
@@ -293,7 +263,7 @@ func (device *Device) RoutineHandshake() {
                        err := binary.Read(reader, binary.LittleEndian, &reply)
                        if err != nil {
                                device.log.Verbosef("Failed to decode cookie reply")
-                               return
+                               goto skip
                        }
 
                        // lookup peer from index
@@ -301,7 +271,7 @@ func (device *Device) RoutineHandshake() {
                        entry := device.indexTable.Lookup(reply.Receiver)
 
                        if entry.peer == nil {
-                               continue
+                               goto skip
                        }
 
                        // consume reply
@@ -313,7 +283,7 @@ func (device *Device) RoutineHandshake() {
                                }
                        }
 
-                       continue
+                       goto skip
 
                case MessageInitiationType, MessageResponseType:
 
@@ -321,7 +291,7 @@ func (device *Device) RoutineHandshake() {
 
                        if !device.cookieChecker.CheckMAC1(elem.packet) {
                                device.log.Verbosef("Received packet with invalid mac1")
-                               continue
+                               goto skip
                        }
 
                        // endpoints destination address is the source of the datagram
@@ -332,19 +302,19 @@ func (device *Device) RoutineHandshake() {
 
                                if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
                                        device.SendHandshakeCookie(&elem)
-                                       continue
+                                       goto skip
                                }
 
                                // check ratelimiter
 
                                if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
-                                       continue
+                                       goto skip
                                }
                        }
 
                default:
                        device.log.Errorf("Invalid packet ended up in the handshake queue")
-                       continue
+                       goto skip
                }
 
                // handle handshake initiation/response content
@@ -359,7 +329,7 @@ func (device *Device) RoutineHandshake() {
                        err := binary.Read(reader, binary.LittleEndian, &msg)
                        if err != nil {
                                device.log.Errorf("Failed to decode initiation message")
-                               continue
+                               goto skip
                        }
 
                        // consume initiation
@@ -367,7 +337,7 @@ func (device *Device) RoutineHandshake() {
                        peer := device.ConsumeMessageInitiation(&msg)
                        if peer == nil {
                                device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
-                               continue
+                               goto skip
                        }
 
                        // update timers
@@ -392,7 +362,7 @@ func (device *Device) RoutineHandshake() {
                        err := binary.Read(reader, binary.LittleEndian, &msg)
                        if err != nil {
                                device.log.Errorf("Failed to decode response message")
-                               continue
+                               goto skip
                        }
 
                        // consume response
@@ -400,7 +370,7 @@ func (device *Device) RoutineHandshake() {
                        peer := device.ConsumeMessageResponse(&msg)
                        if peer == nil {
                                device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
-                               continue
+                               goto skip
                        }
 
                        // update endpoint
@@ -420,13 +390,15 @@ func (device *Device) RoutineHandshake() {
 
                        if err != nil {
                                device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
-                               continue
+                               goto skip
                        }
 
                        peer.timersSessionDerived()
                        peer.timersHandshakeComplete()
                        peer.SendKeepalive()
                }
+       skip:
+               device.PutMessageBuffer(elem.buffer)
        }
 }