]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Work on timer teardown + bug fixes
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 13 Jan 2018 08:00:37 +0000 (09:00 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 13 Jan 2018 08:00:37 +0000 (09:00 +0100)
Added waitgroups to peer struct for routine
start / stop synchronisation

src/conn.go
src/device.go
src/peer.go
src/receive.go
src/send.go
src/timers.go
src/tun.go
src/uapi.go

index ddb7ed1b3227686468dbfa3518f969dcc653b679..1d033ff07167e48b70e1ef5b580b842939496c78 100644 (file)
@@ -64,13 +64,9 @@ func unsafeCloseBind(device *Device) error {
        return err
 }
 
-func updateBind(device *Device) error {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
-
-       netc := &device.net
-       netc.mutex.Lock()
-       defer netc.mutex.Unlock()
+/* Must hold device and net lock
+ */
+func unsafeUpdateBind(device *Device) error {
 
        // close existing sockets
 
@@ -89,6 +85,7 @@ func updateBind(device *Device) error {
                // bind to new port
 
                var err error
+               netc := &device.net
                netc.bind, netc.port, err = CreateBind(netc.port)
                if err != nil {
                        netc.bind = nil
index f4a087c3c053c876eb76ffa17de240508a1fe4ba..5f8e91bfa32b2e309dd20adf050422f12401d642 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "github.com/sasha-s/go-deadlock"
        "runtime"
        "sync"
        "sync/atomic"
@@ -21,12 +22,12 @@ type Device struct {
                messageBuffers sync.Pool
        }
        net struct {
-               mutex  sync.RWMutex
+               mutex  deadlock.RWMutex
                bind   Bind   // bind interface
                port   uint16 // listening port
                fwmark uint32 // mark value (0 = disabled)
        }
-       mutex        sync.RWMutex
+       mutex        deadlock.RWMutex
        privateKey   NoisePrivateKey
        publicKey    NoisePublicKey
        routingTable RoutingTable
@@ -49,8 +50,15 @@ func (device *Device) Up() {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
-       device.isUp.Set(true)
-       updateBind(device)
+       device.net.mutex.Lock()
+       defer device.net.mutex.Unlock()
+
+       if device.isUp.Swap(true) {
+               return
+       }
+
+       unsafeUpdateBind(device)
+
        for _, peer := range device.peers {
                peer.Start()
        }
@@ -60,8 +68,12 @@ func (device *Device) Down() {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
-       device.isUp.Set(false)
+       if !device.isUp.Swap(false) {
+               return
+       }
+
        closeBind(device)
+
        for _, peer := range device.peers {
                peer.Stop()
        }
@@ -75,7 +87,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
        if !ok {
                return
        }
-       peer.mutex.Lock()
        peer.Stop()
        device.routingTable.RemovePeer(peer)
        delete(device.peers, key)
index 7c6ad476697ad21bb7d5a20c39ed21a119128116..3d829897591abf149f4eb31f27ac7c37a67a2cb0 100644 (file)
@@ -8,6 +8,10 @@ import (
        "time"
 )
 
+const (
+       PeerRoutineNumber = 4
+)
+
 type Peer struct {
        id                          uint
        mutex                       sync.RWMutex
@@ -34,7 +38,6 @@ type Peer struct {
                flushNonceQueue    Signal // size 1, empty queued packets
                messageSend        Signal // size 1, message was send to peer
                messageReceived    Signal // size 1, authenticated message recv
-               stop               Signal // size 0, stop all goroutines in peer
        }
        timer struct {
                // state related to WireGuard timers
@@ -54,6 +57,12 @@ type Peer struct {
                outbound chan *QueueOutboundElement // sequential ordering of work
                inbound  chan *QueueInboundElement  // sequential ordering of work
        }
+       routines struct {
+               mutex    sync.Mutex     // held when stopping / starting routines
+               starting sync.WaitGroup // routines pending start
+               stopping sync.WaitGroup // routines pending stop
+               stop     Signal         // size 0, stop all goroutines in peer
+       }
        mac CookieGenerator
 }
 
@@ -121,6 +130,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        peer.signal.handshakeCompleted = NewSignal()
        peer.signal.flushNonceQueue = NewSignal()
 
+       peer.routines.mutex.Lock()
+       peer.routines.stop = NewSignal()
+       peer.routines.mutex.Unlock()
+
        return peer, nil
 }
 
@@ -156,32 +169,43 @@ func (peer *Peer) String() string {
        )
 }
 
-/* Starts all routines for a given peer
- *
- * Requires that the caller holds the exclusive peer lock!
- */
-func unsafePeerStart(peer *Peer) {
-       peer.signal.stop.Broadcast()
-       peer.signal.stop = NewSignal()
+func (peer *Peer) Start() {
+
+       peer.routines.mutex.Lock()
+       defer peer.routines.mutex.Lock()
+
+       // stop & wait for ungoing routines (if any)
+
+       peer.routines.stop.Broadcast()
+       peer.routines.starting.Wait()
+       peer.routines.stopping.Wait()
 
-       var wait sync.WaitGroup
+       // reset signal and start (new) routines
 
-       wait.Add(1)
+       peer.routines.stop = NewSignal()
+       peer.routines.starting.Add(PeerRoutineNumber)
+       peer.routines.stopping.Add(PeerRoutineNumber)
 
        go peer.RoutineNonce()
-       go peer.RoutineTimerHandler(&wait)
+       go peer.RoutineTimerHandler()
        go peer.RoutineSequentialSender()
        go peer.RoutineSequentialReceiver()
 
-       wait.Wait()
-}
-
-func (peer *Peer) Start() {
-       peer.mutex.Lock()
-       unsafePeerStart(peer)
-       peer.mutex.Unlock()
+       peer.routines.starting.Wait()
 }
 
 func (peer *Peer) Stop() {
-       peer.signal.stop.Broadcast()
+
+       peer.routines.mutex.Lock()
+       defer peer.routines.mutex.Lock()
+
+       // stop & wait for ungoing routines (if any)
+
+       peer.routines.stop.Broadcast()
+       peer.routines.starting.Wait()
+       peer.routines.stopping.Wait()
+
+       // reset signal (to handle repeated stopping)
+
+       peer.routines.stop = NewSignal()
 }
index dbd2813ea10e88131325eb86225d3ad6c8c07518..e6e8481001fc35ed66bc8e601b2eb17a1d06a90a 100644 (file)
@@ -497,7 +497,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                select {
 
-               case <-peer.signal.stop.Wait():
+               case <-peer.routines.stop.Wait():
                        logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
                        return
 
index 9537f5eb9113216781ff654598021ea611ab6ed2..fa13c9181accc5cbdac48830e469d9c29ba09bed 100644 (file)
@@ -192,7 +192,7 @@ func (peer *Peer) RoutineNonce() {
        for {
        NextPacket:
                select {
-               case <-peer.signal.stop.Wait():
+               case <-peer.routines.stop.Wait():
                        return
 
                case elem := <-peer.queue.nonce:
@@ -217,7 +217,7 @@ func (peer *Peer) RoutineNonce() {
                                        logDebug.Println("Clearing queue for", peer.String())
                                        peer.FlushNonceQueue()
                                        goto NextPacket
-                               case <-peer.signal.stop.Wait():
+                               case <-peer.routines.stop.Wait():
                                        return
                                }
                        }
@@ -309,15 +309,20 @@ func (device *Device) RoutineEncryption() {
  * The routine terminates then the outbound queue is closed.
  */
 func (peer *Peer) RoutineSequentialSender() {
+
+       defer peer.routines.stopping.Done()
+
        device := peer.device
 
        logDebug := device.log.Debug
        logDebug.Println("Routine, sequential sender, started for", peer.String())
 
+       peer.routines.starting.Done()
+
        for {
                select {
 
-               case <-peer.signal.stop.Wait():
+               case <-peer.routines.stop.Wait():
                        logDebug.Println(
                                "Routine, sequential sender, stopped for", peer.String())
                        return
index f2fed30ec28e101833a686b6b41078312114867d..f1ed9c5bd378d08c22d8d72fb1f541fcac0629d4 100644 (file)
@@ -4,7 +4,6 @@ import (
        "bytes"\r
        "encoding/binary"\r
        "math/rand"\r
-       "sync"\r
        "sync/atomic"\r
        "time"\r
 )\r
@@ -182,7 +181,10 @@ func (peer *Peer) sendNewHandshake() error {
        return err\r
 }\r
 \r
-func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {\r
+func (peer *Peer) RoutineTimerHandler() {\r
+\r
+       defer peer.routines.stopping.Done()\r
+\r
        device := peer.device\r
 \r
        logInfo := device.log.Info\r
@@ -203,15 +205,20 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
                peer.timer.keepalivePersistent.Reset(duration)\r
        }\r
 \r
-       // signal that timers are reset\r
+       // signal synchronised setup complete\r
 \r
-       ready.Done()\r
+       peer.routines.starting.Done()\r
 \r
        // handle timer events\r
 \r
        for {\r
                select {\r
 \r
+               /* stopping */\r
+\r
+               case <-peer.routines.stop.Wait():\r
+                       return\r
+\r
                /* timers */\r
 \r
                // keep-alive\r
@@ -312,9 +319,6 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
 \r
                /* signals */\r
 \r
-               case <-peer.signal.stop.Wait():\r
-                       return\r
-\r
                case <-peer.signal.handshakeBegin.Wait():\r
 \r
                        peer.signal.handshakeBegin.Disable()\r
index 024f0f07133bcadaddb811ec6bca4932ff5aaaad..6259f33a49c97f58798b5dda1b312238c708e317 100644 (file)
@@ -45,14 +45,14 @@ func (device *Device) RoutineTUNEventReader() {
                        }
                }
 
-               if event&TUNEventUp != 0 {
+               if event&TUNEventUp != 0 && !device.isUp.Get() {
                        logInfo.Println("Interface set up")
                        device.Up()
                }
 
-               if event&TUNEventDown != 0 {
+               if event&TUNEventDown != 0 && device.isUp.Get() {
                        logInfo.Println("Interface set down")
-                       device.Up()
+                       device.Down()
                }
        }
 }
index a67bff1424578ca11d3f07ec0790568643808b1f..f66528c7dbe6470743e76ef8cc6c753b2e29023b 100644 (file)
@@ -133,13 +133,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                device.SetPrivateKey(sk)
 
                        case "listen_port":
+
+                               // parse port number
+
                                port, err := strconv.ParseUint(value, 10, 16)
                                if err != nil {
                                        logError.Println("Failed to parse listen_port:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
+
+                               // update port and rebind
+
+                               device.mutex.Lock()
+                               device.net.mutex.Lock()
+
                                device.net.port = uint16(port)
-                               if err := updateBind(device); err != nil {
+                               err = unsafeUpdateBind(device)
+
+                               device.net.mutex.Unlock()
+                               device.mutex.Unlock()
+
+                               if err != nil {
                                        logError.Println("Failed to set listen_port:", err)
                                        return &IPCError{Code: ipcErrorPortInUse}
                                }