]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Add missing locks and fix debug output, and try to flush queues
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 1 May 2018 14:59:13 +0000 (16:59 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 1 May 2018 15:46:28 +0000 (17:46 +0200)
Flushing queues on exit is sort of a partial solution, but this could be
better. Really what we want is for no more packets to be enqueued after
isUp is set to false.

device.go
peer.go
receive.go
send.go
timers.go

index 3ad53c9adfd9c155f0503287e6572af03029cfd7..dddb547dbbdf2a3cd31a54e9184e6f5c551d0b07 100644 (file)
--- a/device.go
+++ b/device.go
@@ -339,6 +339,8 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
 }
 
 func (device *Device) RemoveAllPeers() {
+       device.noise.mutex.Lock()
+       defer device.noise.mutex.Unlock()
 
        device.routing.mutex.Lock()
        defer device.routing.mutex.Unlock()
@@ -354,16 +356,25 @@ func (device *Device) RemoveAllPeers() {
 }
 
 func (device *Device) Close() {
-       device.log.Info.Println("Device closing")
        if device.isClosed.Swap(true) {
                return
        }
-       device.signal.stop.Broadcast()
+       device.log.Info.Println("Device closing")
+       device.state.changing.Set(true)
+       device.state.mutex.Lock()
+       defer device.state.mutex.Unlock()
+
        device.tun.device.Close()
        device.BindClose()
+
        device.isUp.Set(false)
+
+       device.signal.stop.Broadcast()
+
        device.RemoveAllPeers()
        device.rate.limiter.Close()
+
+       device.state.changing.Set(false)
        device.log.Info.Println("Interface closed")
 }
 
diff --git a/peer.go b/peer.go
index f10bfbb7e1d0ec6978702bc0f9223e18b7a3bc5b..ec411b2c2758ae40652954c850708b79f3b70b28 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -195,7 +195,7 @@ func (peer *Peer) Start() {
        }
 
        device := peer.device
-       device.log.Debug.Println(peer.String(), ": Starting...")
+       device.log.Debug.Println(peer.String() + ": Starting...")
 
        // sanity check : these should be 0
 
@@ -242,7 +242,7 @@ func (peer *Peer) Stop() {
        }
 
        device := peer.device
-       device.log.Debug.Println(peer.String(), ": Stopping...")
+       device.log.Debug.Println(peer.String() + ": Stopping...")
 
        // stop & wait for ongoing peer routines
 
index ca2090095ebcc3ea99c4036ffba8e53418e7e554..7d35497108c5ede6c21e21e8c65c4e733fa67128 100644 (file)
@@ -7,6 +7,7 @@ import (
        "golang.org/x/net/ipv4"
        "golang.org/x/net/ipv6"
        "net"
+       "strconv"
        "sync"
        "sync/atomic"
        "time"
@@ -101,7 +102,11 @@ func (device *Device) addToHandshakeQueue(
 func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
 
        logDebug := device.log.Debug
-       logDebug.Println("Routine, receive incoming, IP version:", IP)
+       defer func() {
+               logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+       }()
+
+       logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting")
 
        // receive datagrams until conn is closed
 
@@ -224,15 +229,31 @@ func (device *Device) RoutineDecryption() {
        var nonce [chacha20poly1305.NonceSize]byte
 
        logDebug := device.log.Debug
-       logDebug.Println("Routine, decryption, started for device")
+       defer func() {
+               for {
+                       select {
+                       case elem, ok := <-device.queue.decryption:
+                               if ok {
+                                       elem.Drop()
+                               }
+                       default:
+                               break
+                       }
+               }
+               logDebug.Println("Routine: decryption worker - stopped")
+       }()
+       logDebug.Println("Routine: decryption worker - started")
 
        for {
                select {
                case <-device.signal.stop.Wait():
-                       logDebug.Println("Routine, decryption worker, stopped")
                        return
 
-               case elem := <-device.queue.decryption:
+               case elem, ok := <-device.queue.decryption:
+
+                       if !ok {
+                               return
+                       }
 
                        // check if dropped
 
@@ -282,18 +303,35 @@ func (device *Device) RoutineHandshake() {
        logInfo := device.log.Info
        logError := device.log.Error
        logDebug := device.log.Debug
-       logDebug.Println("Routine, handshake routine, started for device")
+
+       defer func() {
+               for {
+                       select {
+                       case <-device.queue.handshake:
+                       default:
+                               return
+                       }
+               }
+               logDebug.Println("Routine: handshake worker - stopped")
+       }()
+
+       logDebug.Println("Routine: handshake worker - started")
 
        var temp [MessageHandshakeSize]byte
        var elem QueueHandshakeElement
+       var ok bool
 
        for {
                select {
-               case elem = <-device.queue.handshake:
+               case elem, ok = <-device.queue.handshake:
                case <-device.signal.stop.Wait():
                        return
                }
 
+               if !ok {
+                       return
+               }
+
                // handle cookie fields and ratelimiting
 
                switch elem.msgType {
@@ -419,7 +457,7 @@ func (device *Device) RoutineHandshake() {
                        peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
-                       logDebug.Println(peer, ": Received handshake initiation")
+                       logDebug.Println(peer.String() + ": Received handshake initiation")
 
                        // create response
 
@@ -477,7 +515,7 @@ func (device *Device) RoutineHandshake() {
                        peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
-                       logDebug.Println(peer, ": Received handshake response")
+                       logDebug.Println(peer.String() + ": Received handshake response")
 
                        peer.TimerEphemeralKeyCreated()
 
@@ -504,10 +542,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
        defer func() {
                peer.routines.stopping.Done()
-               logDebug.Println(peer.String(), ": Routine, Sequential Receiver, Stopped")
+               logDebug.Println(peer.String() + ": Routine: sequential receiver - stopped")
        }()
 
-       logDebug.Println(peer.String(), ": Routine, Sequential Receiver, Started")
+       logDebug.Println(peer.String() + ": Routine: sequential receiver - started")
 
        peer.routines.starting.Done()
 
diff --git a/send.go b/send.go
index df8efdbe71912395fc041266db0c1eda8245ce7e..5c6b3504af43423babf7877ae269ae100df7e2a7 100644 (file)
--- a/send.go
+++ b/send.go
@@ -121,7 +121,11 @@ func (device *Device) RoutineReadFromTUN() {
        logDebug := device.log.Debug
        logError := device.log.Error
 
-       logDebug.Println("Routine, TUN Reader started")
+       defer func() {
+               logDebug.Println("Routine: TUN reader - stopped")
+       }()
+
+       logDebug.Println("Routine: TUN reader - started")
 
        for {
 
@@ -192,11 +196,11 @@ func (peer *Peer) RoutineNonce() {
 
        defer func() {
                peer.routines.stopping.Done()
-               logDebug.Println(peer.String(), ": Routine, Nonce Worker, Stopped")
+               logDebug.Println(peer.String() + ": Routine: nonce worker - stopped")
        }()
 
        peer.routines.starting.Done()
-       logDebug.Println(peer.String(), ": Routine, Nonce Worker, Started")
+       logDebug.Println(peer.String() + ": Routine: nonce worker - started")
 
        for {
        NextPacket:
@@ -204,7 +208,11 @@ func (peer *Peer) RoutineNonce() {
                case <-peer.routines.stop.Wait():
                        return
 
-               case elem := <-peer.queue.nonce:
+               case elem, ok := <-peer.queue.nonce:
+
+                       if !ok {
+                               return
+                       }
 
                        // wait for key pair
 
@@ -218,13 +226,13 @@ func (peer *Peer) RoutineNonce() {
 
                                peer.signal.handshakeBegin.Send()
 
-                               logDebug.Println(peer.String(), ": Awaiting key-pair")
+                               logDebug.Println(peer.String() + ": Awaiting key-pair")
 
                                select {
                                case <-peer.signal.newKeyPair.Wait():
-                                       logDebug.Println(peer.String(), ": Obtained awaited key-pair")
+                                       logDebug.Println(peer.String() + ": Obtained awaited key-pair")
                                case <-peer.signal.flushNonceQueue.Wait():
-                                       logDebug.Println(peer.String(), ": Flushing nonce queue")
+                                       logDebug.Println(peer.String() + ": Flushing nonce queue")
                                        peer.FlushNonceQueue()
                                        goto NextPacket
                                case <-peer.routines.stop.Wait():
@@ -258,7 +266,22 @@ func (device *Device) RoutineEncryption() {
        var nonce [chacha20poly1305.NonceSize]byte
 
        logDebug := device.log.Debug
-       logDebug.Println("Routine, encryption worker, started")
+
+       defer func() {
+               for {
+                       select {
+                       case elem, ok := <-device.queue.encryption:
+                               if ok {
+                                       elem.Drop()
+                               }
+                       default:
+                               break
+                       }
+               }
+               logDebug.Println("Routine: encryption worker - stopped")
+       }()
+
+       logDebug.Println("Routine: encryption worker - started")
 
        for {
 
@@ -266,10 +289,13 @@ func (device *Device) RoutineEncryption() {
 
                select {
                case <-device.signal.stop.Wait():
-                       logDebug.Println("Routine, encryption worker, stopped")
                        return
 
-               case elem := <-device.queue.encryption:
+               case elem, ok := <-device.queue.encryption:
+
+                       if !ok {
+                               return
+                       }
 
                        // check if dropped
 
@@ -323,21 +349,20 @@ func (peer *Peer) RoutineSequentialSender() {
        device := peer.device
 
        logDebug := device.log.Debug
-       logDebug.Println("Routine, sequential sender, started for", peer.String())
 
        defer func() {
                peer.routines.stopping.Done()
-               logDebug.Println(peer.String(), ": Routine, Sequential sender, Stopped")
+               logDebug.Println(peer.String() + ": Routine: sequential sender - stopped")
        }()
 
+       logDebug.Println(peer.String() + ": Routine: sequential sender - started")
+
        peer.routines.starting.Done()
 
        for {
                select {
 
                case <-peer.routines.stop.Wait():
-                       logDebug.Println(
-                               "Routine, sequential sender, stopped for", peer.String())
                        return
 
                case elem, ok := <-peer.queue.outbound:
index 87255708d79a496b4b8e3e8c72de16a80a327c42..ba0d0e56f6718c1560e8bd9bc0b6b2575553a244 100644 (file)
--- a/timers.go
+++ b/timers.go
@@ -120,7 +120,7 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
  */
 func (peer *Peer) TimerHandshakeComplete() {
        peer.signal.handshakeCompleted.Send()
-       peer.device.log.Info.Println(peer.String(), ": New handshake completed")
+       peer.device.log.Info.Println(peer.String() + ": New handshake completed")
 }
 
 /* Event:
@@ -189,10 +189,12 @@ func (peer *Peer) RoutineTimerHandler() {
        logDebug := device.log.Debug
 
        defer func() {
-               logDebug.Println(peer.String(), ": Routine, Timer handler, Stopped")
+               logDebug.Println(peer.String() + ": Routine: timer handler - stopped")
                peer.routines.stopping.Done()
        }()
 
+       logDebug.Println(peer.String() + ": Routine: timer handler - started")
+
        // reset all timers
 
        peer.timer.keepalivePassive.Stop()
@@ -207,8 +209,6 @@ func (peer *Peer) RoutineTimerHandler() {
                peer.timer.keepalivePersistent.Reset(duration)
        }
 
-       logDebug.Println("Routine, timer handler, started for peer", peer.String())
-
        // signal synchronised setup complete
 
        peer.routines.starting.Done()
@@ -231,14 +231,14 @@ func (peer *Peer) RoutineTimerHandler() {
 
                        interval := peer.persistentKeepaliveInterval
                        if interval > 0 {
-                               logDebug.Println(peer.String(), ": Send keep-alive (persistent)")
+                               logDebug.Println(peer.String() + ": Send keep-alive (persistent)")
                                peer.timer.keepalivePassive.Stop()
                                peer.SendKeepAlive()
                        }
 
                case <-peer.timer.keepalivePassive.Wait():
 
-                       logDebug.Println(peer.String(), ": Send keep-alive (passive)")
+                       logDebug.Println(peer.String() + ": Send keep-alive (passive)")
 
                        peer.SendKeepAlive()
 
@@ -250,7 +250,7 @@ func (peer *Peer) RoutineTimerHandler() {
 
                case <-peer.timer.zeroAllKeys.Wait():
 
-                       logDebug.Println(peer.String(), ": Clear all key-material (timer event)")
+                       logDebug.Println(peer.String() + ": Clear all key-material (timer event)")
 
                        hs := &peer.handshake
                        hs.mutex.Lock()
@@ -283,7 +283,7 @@ func (peer *Peer) RoutineTimerHandler() {
                // handshake timers
 
                case <-peer.timer.handshakeNew.Wait():
-                       logInfo.Println(peer.String(), ": Retrying handshake (timer event)")
+                       logInfo.Println(peer.String() + ": Retrying handshake (timer event)")
                        peer.signal.handshakeBegin.Send()
 
                case <-peer.timer.handshakeTimeout.Wait():
@@ -301,16 +301,16 @@ func (peer *Peer) RoutineTimerHandler() {
                        err := peer.sendNewHandshake()
 
                        if err != nil {
-                               logInfo.Println(peer.String()": Failed to send handshake initiation", err)
+                               logInfo.Println(peer.String()+": Failed to send handshake initiation", err)
                        } else {
-                               logDebug.Println(peer.String(), ": Send handshake initiation (subsequent)")
+                               logDebug.Println(peer.String() + ": Send handshake initiation (subsequent)")
                        }
 
                case <-peer.timer.handshakeDeadline.Wait():
 
                        // clear all queued packets and stop keep-alive
 
-                       logInfo.Println(peer.String(), ": Handshake negotiation timed-out")
+                       logInfo.Println(peer.String() + ": Handshake negotiation timed-out")
 
                        peer.signal.flushNonceQueue.Send()
                        peer.timer.keepalivePersistent.Stop()
@@ -325,16 +325,16 @@ func (peer *Peer) RoutineTimerHandler() {
                        err := peer.sendNewHandshake()
 
                        if err != nil {
-                               logInfo.Println(peer.String()": Failed to send handshake initiation", err)
+                               logInfo.Println(peer.String()+": Failed to send handshake initiation", err)
                        } else {
-                               logDebug.Println(peer.String(), ": Send handshake initiation (initial)")
+                               logDebug.Println(peer.String() + ": Send handshake initiation (initial)")
                        }
 
                        peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
 
                case <-peer.signal.handshakeCompleted.Wait():
 
-                       logInfo.Println(peer.String(), ": Handshake completed")
+                       logInfo.Println(peer.String() + ": Handshake completed")
 
                        atomic.StoreInt64(
                                &peer.stats.lastHandshakeNano,