]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Added initial version of peer teardown
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 26 Jan 2018 21:52:32 +0000 (22:52 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 26 Jan 2018 21:52:32 +0000 (22:52 +0100)
There is a double lock issue with device.Close which has yet to be
resolved.

src/conn.go
src/device.go
src/peer.go
src/receive.go
src/send.go
src/uapi.go

index 1d033ff07167e48b70e1ef5b580b842939496c78..c2f5deeb035882db514c72d4a8d2137798620611 100644 (file)
@@ -64,9 +64,13 @@ func unsafeCloseBind(device *Device) error {
        return err
 }
 
-/* Must hold device and net lock
- */
-func unsafeUpdateBind(device *Device) error {
+func (device *Device) BindUpdate() error {
+       device.mutex.Lock()
+       defer device.mutex.Unlock()
+
+       netc := &device.net
+       netc.mutex.Lock()
+       defer netc.mutex.Unlock()
 
        // close existing sockets
 
@@ -74,18 +78,13 @@ func unsafeUpdateBind(device *Device) error {
                return err
        }
 
-       // assumption: netc.update WaitGroup should be exactly 1
-
        // open new sockets
 
        if device.isUp.Get() {
 
-               device.log.Debug.Println("UDP bind updating")
-
                // bind to new port
 
                var err error
-               netc := &device.net
                netc.bind, netc.port, err = CreateBind(netc.port)
                if err != nil {
                        netc.bind = nil
@@ -109,7 +108,7 @@ func unsafeUpdateBind(device *Device) error {
                        peer.mutex.Unlock()
                }
 
-               // decrease waitgroup to 0
+               // start receiving routines
 
                go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
                go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
@@ -120,7 +119,7 @@ func unsafeUpdateBind(device *Device) error {
        return nil
 }
 
-func closeBind(device *Device) error {
+func (device *Device) BindClose() error {
        device.mutex.Lock()
        device.net.mutex.Lock()
        err := unsafeCloseBind(device)
index 5f8e91bfa32b2e309dd20adf050422f12401d642..f1c09c6bf8c694ee641d8be01d1884295973e4b2 100644 (file)
@@ -9,7 +9,7 @@ import (
 )
 
 type Device struct {
-       isUp      AtomicBool // device is up (TUN interface up)?
+       isUp      AtomicBool // device is (going) up
        isClosed  AtomicBool // device is closed? (acting as guard)
        log       *Logger    // collection of loggers for levels
        idCounter uint       // for assigning debug ids to peers
@@ -18,6 +18,11 @@ type Device struct {
                device TUNDevice
                mtu    int32
        }
+       state struct {
+               mutex    deadlock.Mutex
+               changing AtomicBool
+               current  bool
+       }
        pool struct {
                messageBuffers sync.Pool
        }
@@ -46,37 +51,86 @@ type Device struct {
        mac            CookieChecker
 }
 
-func (device *Device) Up() {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
+func deviceUpdateState(device *Device) {
 
-       device.net.mutex.Lock()
-       defer device.net.mutex.Unlock()
+       // check if state already being updated (guard)
 
-       if device.isUp.Swap(true) {
+       if device.state.changing.Swap(true) {
                return
        }
 
-       unsafeUpdateBind(device)
+       // compare to current state of device
+
+       device.state.mutex.Lock()
+
+       newIsUp := device.isUp.Get()
+
+       if newIsUp == device.state.current {
+               device.state.mutex.Unlock()
+               device.state.changing.Set(false)
+               return
+       }
+
+       device.state.mutex.Unlock()
+
+       // change state of device
+
+       switch newIsUp {
+       case true:
+
+               // start listener
+
+               if err := device.BindUpdate(); err != nil {
+                       device.isUp.Set(false)
+                       break
+               }
+
+               // start every peer
+
+               for _, peer := range device.peers {
+                       peer.Start()
+               }
+
+       case false:
+
+               // stop listening
+
+               device.BindClose()
 
-       for _, peer := range device.peers {
-               peer.Start()
+               // stop every peer
+
+               for _, peer := range device.peers {
+                       peer.Stop()
+               }
        }
+
+       // update state variables
+       // and check for state change in the mean time
+
+       device.state.current = newIsUp
+       device.state.changing.Set(false)
+       deviceUpdateState(device)
 }
 
-func (device *Device) Down() {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
+func (device *Device) Up() {
+
+       // closed device cannot be brought up
 
-       if !device.isUp.Swap(false) {
+       if device.isClosed.Get() {
                return
        }
 
-       closeBind(device)
+       device.state.mutex.Lock()
+       device.isUp.Set(true)
+       device.state.mutex.Unlock()
+       deviceUpdateState(device)
+}
 
-       for _, peer := range device.peers {
-               peer.Stop()
-       }
+func (device *Device) Down() {
+       device.state.mutex.Lock()
+       device.isUp.Set(false)
+       device.state.mutex.Unlock()
+       deviceUpdateState(device)
 }
 
 /* Warning:
@@ -87,7 +141,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
        if !ok {
                return
        }
-       peer.Stop()
        device.routingTable.RemovePeer(peer)
        delete(device.peers, key)
 }
@@ -231,20 +284,30 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
 func (device *Device) RemoveAllPeers() {
        device.mutex.Lock()
        defer device.mutex.Unlock()
-       for key := range device.peers {
-               removePeerUnsafe(device, key)
+
+       for key, peer := range device.peers {
+               peer.Stop()
+               peer, ok := device.peers[key]
+               if !ok {
+                       return
+               }
+               device.routingTable.RemovePeer(peer)
+               delete(device.peers, key)
        }
 }
 
 func (device *Device) Close() {
+       device.log.Info.Println("Device closing")
        if device.isClosed.Swap(true) {
                return
        }
-       device.log.Info.Println("Closing device")
-       device.RemoveAllPeers()
        device.signal.stop.Broadcast()
        device.tun.device.Close()
-       closeBind(device)
+       device.BindClose()
+       device.isUp.Set(false)
+       println("remove")
+       device.RemoveAllPeers()
+       device.log.Info.Println("Interface closed")
 }
 
 func (device *Device) Wait() chan struct{} {
index 3d829897591abf149f4eb31f27ac7c37a67a2cb0..5ad45118f6d0d9db717ceff189edb675612765c2 100644 (file)
@@ -4,6 +4,7 @@ import (
        "encoding/base64"
        "errors"
        "fmt"
+       "github.com/sasha-s/go-deadlock"
        "sync"
        "time"
 )
@@ -14,7 +15,8 @@ const (
 
 type Peer struct {
        id                          uint
-       mutex                       sync.RWMutex
+       isRunning                   AtomicBool
+       mutex                       deadlock.RWMutex
        persistentKeepaliveInterval uint64
        keyPairs                    KeyPairs
        handshake                   Handshake
@@ -26,7 +28,7 @@ type Peer struct {
                lastHandshakeNano int64  // nano seconds since epoch
        }
        time struct {
-               mutex         sync.RWMutex
+               mutex         deadlock.RWMutex
                lastSend      time.Time // last send message
                lastHandshake time.Time // last completed handshake
                nextKeepalive time.Time
@@ -58,7 +60,7 @@ type Peer struct {
                inbound  chan *QueueInboundElement  // sequential ordering of work
        }
        routines struct {
-               mutex    sync.Mutex     // held when stopping / starting routines
+               mutex    deadlock.Mutex // held when stopping / starting routines
                starting sync.WaitGroup // routines pending start
                stopping sync.WaitGroup // routines pending stop
                stop     Signal         // size 0, stop all goroutines in peer
@@ -67,6 +69,14 @@ type Peer struct {
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+
+       if device.isClosed.Get() {
+               return nil, errors.New("Device closed")
+       }
+
+       device.mutex.Lock()
+       defer device.mutex.Unlock()
+
        // create peer
 
        peer := new(Peer)
@@ -75,17 +85,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        peer.mac.Init(pk)
        peer.device = device
+       peer.isRunning.Set(false)
 
+       peer.timer.zeroAllKeys = NewTimer()
        peer.timer.keepalivePersistent = NewTimer()
        peer.timer.keepalivePassive = NewTimer()
-       peer.timer.zeroAllKeys = NewTimer()
        peer.timer.handshakeNew = NewTimer()
        peer.timer.handshakeDeadline = NewTimer()
        peer.timer.handshakeTimeout = NewTimer()
 
        // assign id for debugging
 
-       device.mutex.Lock()
        peer.id = device.idCounter
        device.idCounter += 1
 
@@ -102,7 +112,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
                return nil, errors.New("Adding existing peer")
        }
        device.peers[pk] = peer
-       device.mutex.Unlock()
 
        // precompute DH
 
@@ -117,23 +126,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        peer.endpoint = nil
 
-       // prepare queuing
-
-       peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
-       peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
-       peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
-
        // prepare signaling & routines
 
-       peer.signal.newKeyPair = NewSignal()
-       peer.signal.handshakeBegin = NewSignal()
-       peer.signal.handshakeCompleted = NewSignal()
-       peer.signal.flushNonceQueue = NewSignal()
-
        peer.routines.mutex.Lock()
        peer.routines.stop = NewSignal()
        peer.routines.mutex.Unlock()
 
+       // start peer
+
+       peer.device.state.mutex.Lock()
+       if peer.device.isUp.Get() {
+               peer.Start()
+       }
+       peer.device.state.mutex.Unlock()
+
        return peer, nil
 }
 
@@ -148,6 +154,10 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
                return errors.New("No known endpoint for peer")
        }
 
+       if peer.device.net.bind == nil {
+               return errors.New("No bind")
+       }
+
        return peer.device.net.bind.Send(buffer, peer.endpoint)
 }
 
@@ -174,12 +184,26 @@ func (peer *Peer) Start() {
        peer.routines.mutex.Lock()
        defer peer.routines.mutex.Lock()
 
+       peer.device.log.Debug.Println("Starting:", peer.String())
+
        // stop & wait for ungoing routines (if any)
 
+       peer.isRunning.Set(false)
        peer.routines.stop.Broadcast()
        peer.routines.starting.Wait()
        peer.routines.stopping.Wait()
 
+       // prepare queues
+
+       peer.signal.newKeyPair = NewSignal()
+       peer.signal.handshakeBegin = NewSignal()
+       peer.signal.handshakeCompleted = NewSignal()
+       peer.signal.flushNonceQueue = NewSignal()
+
+       peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
+       peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+       peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+
        // reset signal and start (new) routines
 
        peer.routines.stop = NewSignal()
@@ -192,6 +216,7 @@ func (peer *Peer) Start() {
        go peer.RoutineSequentialReceiver()
 
        peer.routines.starting.Wait()
+       peer.isRunning.Set(true)
 }
 
 func (peer *Peer) Stop() {
@@ -199,13 +224,22 @@ func (peer *Peer) Stop() {
        peer.routines.mutex.Lock()
        defer peer.routines.mutex.Lock()
 
+       peer.device.log.Debug.Println("Stopping:", peer.String())
+
        // stop & wait for ungoing routines (if any)
 
        peer.routines.stop.Broadcast()
        peer.routines.starting.Wait()
        peer.routines.stopping.Wait()
 
+       // close queues
+
+       close(peer.queue.nonce)
+       close(peer.queue.outbound)
+       close(peer.queue.inbound)
+
        // reset signal (to handle repeated stopping)
 
        peer.routines.stop = NewSignal()
+       peer.isRunning.Set(false)
 }
index 0b87a3cd4d1fc5588e8a6cde38900cb4bc7fe8ff..5ad7c4b86412e79561d1177e6d4839c7da645134 100644 (file)
@@ -123,7 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                case ipv6.Version:
                        size, endpoint, err = bind.ReceiveIPv6(buffer[:])
                default:
-                       return
+                       panic("invalid IP version")
                }
 
                if err != nil {
@@ -184,9 +184,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
 
                        // add to decryption queues
 
-                       device.addToDecryptionQueue(device.queue.decryption, elem)
-                       device.addToInboundQueue(peer.queue.inbound, elem)
-                       buffer = device.GetMessageBuffer()
+                       if peer.isRunning.Get() {
+                               device.addToDecryptionQueue(device.queue.decryption, elem)
+                               device.addToInboundQueue(peer.queue.inbound, elem)
+                               buffer = device.GetMessageBuffer()
+                       }
 
                        continue
 
@@ -308,13 +310,20 @@ func (device *Device) RoutineHandshake() {
                                return
                        }
 
-                       // lookup peer and consume response
+                       // lookup peer from index
 
                        entry := device.indices.Lookup(reply.Receiver)
+
                        if entry.peer == nil {
                                continue
                        }
-                       entry.peer.mac.ConsumeReply(&reply)
+
+                       // consume reply
+
+                       if peer := entry.peer; peer.isRunning.Get() {
+                               peer.mac.ConsumeReply(&reply)
+                       }
+
                        continue
 
                case MessageInitiationType, MessageResponseType:
index fa13c9181accc5cbdac48830e469d9c29ba09bed..e0a546d8b46d308fcc837d54865677917e0b2b33 100644 (file)
@@ -170,9 +170,11 @@ func (device *Device) RoutineReadFromTUN() {
 
                // insert into nonce/pre-handshake queue
 
-               peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
-               addToOutboundQueue(peer.queue.nonce, elem)
-               elem = device.NewOutboundElement()
+               if peer.isRunning.Get() {
+                       peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
+                       addToOutboundQueue(peer.queue.nonce, elem)
+                       elem = device.NewOutboundElement()
+               }
        }
 }
 
index f66528c7dbe6470743e76ef8cc6c753b2e29023b..68ebe43111c46cd6763c9ee579e90f6bd09e3580 100644 (file)
@@ -144,16 +144,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                // update port and rebind
 
-                               device.mutex.Lock()
                                device.net.mutex.Lock()
-
                                device.net.port = uint16(port)
-                               err = unsafeUpdateBind(device)
-
                                device.net.mutex.Unlock()
-                               device.mutex.Unlock()
 
-                               if err != nil {
+                               if err := device.BindUpdate(); err != nil {
                                        logError.Println("Failed to set listen_port:", err)
                                        return &IPCError{Code: ipcErrorPortInUse}
                                }
@@ -179,6 +174,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                device.net.fwmark = uint32(fwmark)
                                device.net.mutex.Unlock()
 
+                               if err := device.BindUpdate(); err != nil {
+                                       logError.Println("Failed to update fwmark:", err)
+                                       return &IPCError{Code: ipcErrorPortInUse}
+                               }
+
                        case "public_key":
                                // switch to peer configuration
                                deviceConfig = false