]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Improved timer code
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 27 Jul 2017 21:45:37 +0000 (23:45 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 27 Jul 2017 21:45:37 +0000 (23:45 +0200)
src/constants.go
src/noise_protocol.go
src/peer.go
src/receive.go
src/send.go
src/timers.go

index 6b0d41459a2363773008612b6be855cbcbbdb39d..09d33d858e1128da69a62a9d5294eb12a0cfb35a 100644 (file)
@@ -20,6 +20,7 @@ const (
 
 const (
        RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout
+       NewHandshakeTime        = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message
 )
 
 /* Implementation specific constants */
index 5fe6fb2c37376f3f29e24d1026c093c0e371a74a..e2ff5736eb968bcc9b34e7ff404ecfc3b01687ce 100644 (file)
@@ -37,6 +37,7 @@ const (
        MessageCookieReplySize     = 64
        MessageTransportHeaderSize = 16
        MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
+       MessageKeepaliveSize       = MessageTransportSize
 )
 
 const (
@@ -253,8 +254,6 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
                }
                hash = mixHash(hash, msg.Timestamp[:])
 
-               // TODO: check for flood attack
-
                // check for replay attack
 
                return timestamp.After(handshake.lastTimestamp)
index 8eea92946eee8898b8aa0b0d2409c9def5de7a96..9136959d6b4c192225a8027cf6795a932a0bc4f1 100644 (file)
@@ -40,21 +40,22 @@ type Peer struct {
                stop               chan struct{} // (size 0) : close to stop all goroutines for peer
        }
        timer struct {
-               /* Both keep-alive timers acts as one (see timers.go)
-                * They are kept seperate to simplify the implementation.
-                */
                keepalivePersistent *time.Timer // set for persistent keepalives
                keepalivePassive    *time.Timer // set upon recieving messages
-               zeroAllKeys         *time.Timer // zero all key material after RejectAfterTime*3
+               newHandshake        *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
+               zeroAllKeys         *time.Timer // zero all key material (after RejectAfterTime*3)
+
+               pendingKeepalivePassive bool
+               pendingNewHandshake     bool
+               pendingZeroAllKeys      bool
+
+               needAnotherKeepalive bool
        }
        queue struct {
                nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
                outbound chan *QueueOutboundElement // sequential ordering of work
                inbound  chan *QueueInboundElement  // sequential ordering of work
        }
-       flags struct {
-               keepaliveWaiting int32
-       }
        mac MACStatePeer
 }
 
@@ -68,12 +69,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        peer.mac.Init(pk)
        peer.device = device
 
-       peer.timer.keepalivePassive = NewStoppedTimer()
        peer.timer.keepalivePersistent = NewStoppedTimer()
+       peer.timer.keepalivePassive = NewStoppedTimer()
+       peer.timer.newHandshake = NewStoppedTimer()
        peer.timer.zeroAllKeys = NewStoppedTimer()
 
-       peer.flags.keepaliveWaiting = AtomicFalse
-
        // assign id for debugging
 
        device.mutex.Lock()
index d97ca41de55b3b0ef4b3e77b823ddb47ec8ee29e..c74211b99aeaca35ee4d201527121dd7a693a885 100644 (file)
@@ -288,6 +288,7 @@ func (device *Device) RoutineHandshake() {
        logDebug := device.log.Debug
        logDebug.Println("Routine, handshake routine, started for device")
 
+       var temp [256]byte
        var elem QueueHandshakeElement
 
        for {
@@ -363,6 +364,7 @@ func (device *Device) RoutineHandshake() {
                                        )
                                        return
                                }
+                               peer.TimerPacketReceived()
 
                                // update endpoint
 
@@ -378,17 +380,19 @@ func (device *Device) RoutineHandshake() {
                                        return
                                }
 
+                               peer.TimerEphemeralKeyCreated()
+
                                logDebug.Println("Creating response message for", peer.String())
 
-                               outElem := device.NewOutboundElement()
-                               writer := bytes.NewBuffer(outElem.buffer[:0])
+                               writer := bytes.NewBuffer(temp[:0])
                                binary.Write(writer, binary.LittleEndian, response)
-                               outElem.packet = writer.Bytes()
-                               peer.mac.AddMacs(outElem.packet)
-                               addToOutboundQueue(peer.queue.outbound, outElem)
+                               packet := writer.Bytes()
+                               peer.mac.AddMacs(packet)
 
-                               // create new keypair
+                               // send response
 
+                               peer.SendBuffer(packet)
+                               peer.TimerPacketSent()
                                peer.NewKeyPair()
 
                        case MessageResponseType:
@@ -418,12 +422,11 @@ func (device *Device) RoutineHandshake() {
                                        )
                                        return
                                }
-                               kp := peer.NewKeyPair()
-                               if kp == nil {
-                                       logDebug.Println("Failed to derieve key-pair")
-                               }
+
+                               peer.TimerPacketReceived()
+                               peer.TimerHandshakeComplete()
+                               peer.NewKeyPair()
                                peer.SendKeepAlive()
-                               peer.EventHandshakeComplete()
 
                        default:
                                logError.Println("Invalid message type in handshake queue")
@@ -464,12 +467,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                return
                        }
 
-                       // time (passive) keep-alive
-
-                       peer.TimerStartKeepalive()
-
-                       // refresh key material (rekey)
-
+                       peer.TimerPacketReceived()
+                       peer.TimerTransportReceived()
                        peer.KeepKeyFreshReceiving()
 
                        // check if using new key-pair
@@ -477,7 +476,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        kp := &peer.keyPairs
                        kp.mutex.Lock()
                        if kp.next == elem.keyPair {
-                               peer.EventHandshakeComplete()
+                               peer.TimerHandshakeComplete()
                                kp.previous = kp.current
                                kp.current = kp.next
                                kp.next = nil
@@ -490,6 +489,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                logDebug.Println("Received keep-alive from", peer.String())
                                return
                        }
+                       peer.TimerDataReceived()
 
                        // verify source and strip padding
 
index 7cdb806726976787b6dc35a458bead601780cb37..37078b97680b1920d12670dbcf3bad0b90a00058 100644 (file)
@@ -2,6 +2,7 @@ package main
 
 import (
        "encoding/binary"
+       "errors"
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/net/ipv4"
        "golang.org/x/net/ipv6"
@@ -51,6 +52,11 @@ func (peer *Peer) FlushNonceQueue() {
        }
 }
 
+var (
+       ErrorNoEndpoint   = errors.New("No known endpoint for peer")
+       ErrorNoConnection = errors.New("No UDP socket for device")
+)
+
 func (device *Device) NewOutboundElement() *QueueOutboundElement {
        return &QueueOutboundElement{
                dropped: AtomicFalse,
@@ -103,6 +109,25 @@ func addToEncryptionQueue(
        }
 }
 
+func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
+
+       peer.mutex.RLock()
+       endpoint := peer.endpoint
+       peer.mutex.RUnlock()
+       if endpoint == nil {
+               return 0, ErrorNoEndpoint
+       }
+
+       peer.device.net.mutex.RLock()
+       conn := peer.device.net.conn
+       peer.device.net.mutex.RUnlock()
+       if conn == nil {
+               return 0, ErrorNoConnection
+       }
+
+       return conn.WriteToUDP(buffer, endpoint)
+}
+
 /* Reads packets from the TUN and inserts
  * into nonce queue for peer
  *
@@ -349,42 +374,27 @@ func (peer *Peer) RoutineSequentialSender() {
 
                case elem := <-peer.queue.outbound:
                        elem.mutex.Lock()
+                       if elem.IsDropped() {
+                               continue
+                       }
 
-                       func() {
-                               if elem.IsDropped() {
-                                       return
-                               }
-
-                               // get endpoint and connection
-
-                               peer.mutex.RLock()
-                               endpoint := peer.endpoint
-                               peer.mutex.RUnlock()
-                               if endpoint == nil {
-                                       logDebug.Println("No endpoint for", peer.String())
-                                       return
-                               }
-
-                               device.net.mutex.RLock()
-                               conn := device.net.conn
-                               device.net.mutex.RUnlock()
-                               if conn == nil {
-                                       logDebug.Println("No source for device")
-                                       return
-                               }
-
-                               // send message and refresh keys
+                       // send message and return buffer to pool
 
-                               _, err := conn.WriteToUDP(elem.packet, endpoint)
-                               if err != nil {
-                                       return
-                               }
+                       length := uint64(len(elem.packet))
+                       _, err := peer.SendBuffer(elem.packet)
+                       device.PutMessageBuffer(elem.buffer)
+                       if err != nil {
+                               continue
+                       }
+                       atomic.AddUint64(&peer.stats.txBytes, length)
 
-                               atomic.AddUint64(&peer.stats.txBytes, uint64(len(elem.packet)))
-                               peer.TimerResetKeepalive()
-                       }()
+                       // update timers
 
-                       device.PutMessageBuffer(elem.buffer)
+                       peer.TimerPacketSent()
+                       if len(elem.packet) != MessageKeepaliveSize {
+                               peer.TimerDataSent()
+                       }
+                       peer.KeepKeyFreshSending()
                }
        }
 }
index 24544147e7730e2cbd1d27f4d8487e19b8a2d6b6..5a16e9bfe3280352b84c10618c3ed5ff30db52bb 100644 (file)
@@ -44,21 +44,6 @@ func (peer *Peer) KeepKeyFreshReceiving() {
        }
 }
 
-/* Called after succesfully completing a handshake.
- * i.e. after:
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-func (peer *Peer) EventHandshakeComplete() {
-       peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
-       peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
-       atomic.StoreInt64(
-               &peer.stats.lastHandshakeNano,
-               time.Now().UnixNano(),
-       )
-       signalSend(peer.signal.handshakeCompleted)
-}
-
 /* Queues a keep-alive if no packets are queued for peer
  */
 func (peer *Peer) SendKeepAlive() bool {
@@ -75,69 +60,89 @@ func (peer *Peer) SendKeepAlive() bool {
        return true
 }
 
-/* Starts the "keep-alive" timer
- * (if not already running),
- * in response to incomming messages
+/* Authenticated data packet send
+ * Always called together with peer.EventPacketSend
+ *
+ * - Start new handshake timer
  */
-func (peer *Peer) TimerStartKeepalive() {
-
-       // check if acknowledgement timer set yet
+func (peer *Peer) TimerDataSent() {
+       timerStop(peer.timer.keepalivePassive)
+       if !peer.timer.pendingNewHandshake {
+               peer.timer.pendingNewHandshake = true
+               peer.timer.newHandshake.Reset(NewHandshakeTime)
+       }
+}
 
-       var waiting int32 = AtomicTrue
-       waiting = atomic.SwapInt32(&peer.flags.keepaliveWaiting, waiting)
-       if waiting == AtomicTrue {
+/* Event:
+ * Received non-empty (authenticated) transport message
+ *
+ * - Start passive keep-alive timer
+ */
+func (peer *Peer) TimerDataReceived() {
+       if peer.timer.pendingKeepalivePassive {
+               peer.timer.needAnotherKeepalive = true
                return
        }
+       peer.timer.pendingKeepalivePassive = false
+       peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
+}
 
-       // timer not yet set, start it
+/* Event:
+ * Any (authenticated) transport message received
+ * (keep-alive or data)
+ */
+func (peer *Peer) TimerTransportReceived() {
+       timerStop(peer.timer.newHandshake)
+}
 
-       wait := KeepaliveTimeout
+/* Event:
+ * Any packet send to the peer.
+ */
+func (peer *Peer) TimerPacketSent() {
        interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
        if interval > 0 {
                duration := time.Duration(interval) * time.Second
-               if duration < wait {
-                       wait = duration
-               }
+               peer.timer.keepalivePersistent.Reset(duration)
        }
 }
 
-/* Resets both keep-alive timers
+/* Event:
+ * Any authenticated packet received from peer
  */
-func (peer *Peer) TimerResetKeepalive() {
-
-       // reset persistent timer
-
-       interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
-       if interval > 0 {
-               peer.timer.keepalivePersistent.Reset(
-                       time.Duration(interval) * time.Second,
-               )
-       }
-
-       // stop acknowledgement timer
-
-       timerStop(peer.timer.keepalivePassive)
-       atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
+func (peer *Peer) TimerPacketReceived() {
+       peer.TimerPacketSent()
 }
 
-func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) {
-
-       // create initiation
+/* Called after succesfully completing a handshake.
+ * i.e. after:
+ *
+ * - Valid handshake response
+ * - First transport message under the "next" key
+ */
+func (peer *Peer) TimerHandshakeComplete() {
+       timerStop(peer.timer.zeroAllKeys)
+       atomic.StoreInt64(
+               &peer.stats.lastHandshakeNano,
+               time.Now().UnixNano(),
+       )
+       signalSend(peer.signal.handshakeCompleted)
+       peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
+}
 
-       elem := peer.device.NewOutboundElement()
-       msg, err := peer.device.CreateMessageInitiation(peer)
-       if err != nil {
-               return nil, err
+/* Called whenever an ephemeral key is generated
+ * i.e after:
+ *
+ * CreateMessageInitiation
+ * CreateMessageResponse
+ *
+ * Schedules the deletion of all key material
+ * upon failure to complete a handshake
+ */
+func (peer *Peer) TimerEphemeralKeyCreated() {
+       if !peer.timer.pendingZeroAllKeys {
+               peer.timer.pendingZeroAllKeys = true
+               peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
        }
-
-       // marshal & schedule for sending
-
-       writer := bytes.NewBuffer(elem.buffer[:0])
-       binary.Write(writer, binary.LittleEndian, msg)
-       elem.packet = writer.Bytes()
-       peer.mac.AddMacs(elem.packet)
-       addToOutboundQueue(peer.queue.outbound, elem)
-       return elem, err
 }
 
 func (peer *Peer) RoutineTimerHandler() {
@@ -157,17 +162,30 @@ func (peer *Peer) RoutineTimerHandler() {
 
                case <-peer.timer.keepalivePersistent.C:
 
-                       logDebug.Println("Sending persistent keep-alive to", peer.String())
-
-                       peer.SendKeepAlive()
-                       peer.TimerResetKeepalive()
+                       interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
+                       if interval > 0 {
+                               logDebug.Println("Sending persistent keep-alive to", peer.String())
+                               peer.SendKeepAlive()
+                       }
 
                case <-peer.timer.keepalivePassive.C:
 
-                       logDebug.Println("Sending passive persistent keep-alive to", peer.String())
+                       logDebug.Println("Sending passive keep-alive to", peer.String())
 
                        peer.SendKeepAlive()
-                       peer.TimerResetKeepalive()
+
+                       if peer.timer.needAnotherKeepalive {
+                               peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
+                               peer.timer.needAnotherKeepalive = true
+                       }
+
+               // unresponsive session
+
+               case <-peer.timer.newHandshake.C:
+
+                       logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
+
+                       signalSend(peer.signal.handshakeBegin)
 
                // clear key material
 
@@ -175,13 +193,15 @@ func (peer *Peer) RoutineTimerHandler() {
 
                        logDebug.Println("Clearing all key material for", peer.String())
 
+                       hs := &peer.handshake
+                       hs.mutex.Lock()
+
                        kp := &peer.keyPairs
                        kp.mutex.Lock()
 
-                       hs := &peer.handshake
-                       hs.mutex.Lock()
+                       peer.timer.pendingZeroAllKeys = false
 
-                       // unmap local indecies
+                       // unmap indecies
 
                        indices.mutex.Lock()
                        if kp.previous != nil {
@@ -224,80 +244,103 @@ func (peer *Peer) RoutineTimerHandler() {
 func (peer *Peer) RoutineHandshakeInitiator() {
        device := peer.device
 
-       var elem *QueueOutboundElement
-
        logInfo := device.log.Info
        logError := device.log.Error
        logDebug := device.log.Debug
        logDebug.Println("Routine, handshake initator, started for", peer.String())
 
+       var temp [256]byte
+
        for {
 
                // wait for signal
 
                select {
                case <-peer.signal.handshakeBegin:
+                       signalSend(peer.signal.handshakeBegin)
                case <-peer.signal.stop:
                        return
                }
 
                // wait for handshake
 
-               func() {
-                       var err error
-                       var deadline time.Time
-                       for attempts := uint(1); ; attempts++ {
-
-                               // clear completed signal
-
-                               select {
-                               case <-peer.signal.handshakeCompleted:
-                               case <-peer.signal.stop:
-                                       return
-                               default:
-                               }
-
-                               // create initiation
-
-                               if elem != nil {
-                                       elem.Drop()
-                               }
-                               elem, err = peer.BeginHandshakeInitiation()
-                               if err != nil {
-                                       logError.Println("Failed to create initiation message", err, "for", peer.String())
-                                       return
-                               }
-
-                               // set timeout
-
-                               if attempts == 1 {
-                                       deadline = time.Now().Add(MaxHandshakeAttemptTime)
-                               }
-                               timeout := time.NewTimer(RekeyTimeout)
-                               logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
-
-                               // wait for handshake or timeout
-
-                               select {
-
-                               case <-peer.signal.stop:
-                                       return
-
-                               case <-peer.signal.handshakeCompleted:
-                                       <-timeout.C
-                                       return
-
-                               case <-timeout.C:
-                                       if deadline.Before(time.Now().Add(RekeyTimeout)) {
-                                               logInfo.Println("Handshake negotiation timed out for", peer.String())
-                                               signalSend(peer.signal.flushNonceQueue)
-                                               timerStop(peer.timer.keepalivePersistent)
-                                               timerStop(peer.timer.keepalivePassive)
-                                               return
-                                       }
-                               }
+               deadline := time.Now().Add(MaxHandshakeAttemptTime)
+
+       Loop:
+               for attempts := uint(1); ; attempts++ {
+
+                       // clear completed signal
+
+                       select {
+                       case <-peer.signal.handshakeCompleted:
+                       case <-peer.signal.stop:
+                               return
+                       default:
+                       }
+
+                       // check if sufficient time for retry
+
+                       if deadline.Before(time.Now().Add(RekeyTimeout)) {
+                               logInfo.Println("Handshake negotiation timed out for", peer.String())
+                               signalSend(peer.signal.flushNonceQueue)
+                               timerStop(peer.timer.keepalivePersistent)
+                               timerStop(peer.timer.keepalivePassive)
+                               break Loop
+                       }
+
+                       // create initiation message
+
+                       msg, err := peer.device.CreateMessageInitiation(peer)
+                       if err != nil {
+                               logError.Println("Failed to create handshake initiation message:", err)
+                               break Loop
+                       }
+                       peer.TimerEphemeralKeyCreated()
+
+                       // marshal and send
+
+                       writer := bytes.NewBuffer(temp[:0])
+                       binary.Write(writer, binary.LittleEndian, msg)
+                       packet := writer.Bytes()
+                       peer.mac.AddMacs(packet)
+                       peer.TimerPacketSent()
+
+                       _, err = peer.SendBuffer(packet)
+                       if err != nil {
+                               logError.Println(
+                                       "Failed to send handshake initiation message to",
+                                       peer.String(), ":", err,
+                               )
+                               continue
+                       }
+
+                       // set timeout
+
+                       timeout := time.NewTimer(RekeyTimeout)
+                       logDebug.Println(
+                               "Handshake initiation attempt",
+                               attempts, "sent to", peer.String(),
+                       )
+
+                       // wait for handshake or timeout
+
+                       select {
+
+                       case <-peer.signal.stop:
+                               return
+
+                       case <-peer.signal.handshakeCompleted:
+                               <-timeout.C
+                               break Loop
+
+                       case <-timeout.C:
+                               continue
+
                        }
-               }()
+
+               }
+
+               // allow new signal to be set
 
                signalClear(peer.signal.handshakeBegin)
        }