]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Rewrite timers and related state machines
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 7 May 2018 20:27:03 +0000 (22:27 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 10 May 2018 14:08:03 +0000 (16:08 +0200)
14 files changed:
constants.go
device.go
event.go [deleted file]
index.go
keypair.go
main.go
noise-protocol.go
noise_test.go
peer.go
receive.go
send.go
signal.go [deleted file]
timers.go
uapi.go

index 04b75d785ffaeac580b2b607206663bb561554ba..01af1bb561c7ea738e69f9c19c7a33d627af6dd8 100644 (file)
@@ -12,21 +12,18 @@ import (
 /* Specification constants */
 
 const (
-       RekeyAfterMessages     = (1 << 64) - (1 << 16) - 1
-       RejectAfterMessages    = (1 << 64) - (1 << 4) - 1
-       RekeyAfterTime         = time.Second * 120
-       RekeyAttemptTime       = time.Second * 90
-       RekeyTimeout           = time.Second * 5
-       RejectAfterTime        = time.Second * 180
-       KeepaliveTimeout       = time.Second * 10
-       CookieRefreshTime      = time.Second * 120
-       HandshakeInitationRate = time.Second / 20
-       PaddingMultiple        = 16
-)
-
-const (
-       RekeyAfterTimeReceiving = RejectAfterTime - KeepaliveTimeout - RekeyTimeout
-       NewHandshakeTime        = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message
+       RekeyAfterMessages      = (1 << 64) - (1 << 16) - 1
+       RejectAfterMessages     = (1 << 64) - (1 << 4) - 1
+       RekeyAfterTime          = time.Second * 120
+       RekeyAttemptTime        = time.Second * 90
+       RekeyTimeout            = time.Second * 5
+       MaxTimerHandshakes      = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
+       RekeyTimeoutJitterMaxMs = 334
+       RejectAfterTime         = time.Second * 180
+       KeepaliveTimeout        = time.Second * 10
+       CookieRefreshTime       = time.Second * 120
+       HandshakeInitationRate  = time.Second / 20
+       PaddingMultiple         = 16
 )
 
 /* Implementation specific constants */
index c714b213776f25ff36ad370696b456ca0efa2e9a..e127b5b376da50eb6d8af6b93ea30c5ec0a0c94c 100644 (file)
--- a/device.go
+++ b/device.go
@@ -74,8 +74,8 @@ type Device struct {
                handshake  chan QueueHandshakeElement
        }
 
-       signal struct {
-               stop Signal
+       signals struct {
+               stop chan struct{}
        }
 
        tun struct {
@@ -302,7 +302,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
 
        // prepare signals
 
-       device.signal.stop = NewSignal()
+       device.signals.stop = make(chan struct{}, 1)
 
        // prepare net
 
@@ -400,7 +400,7 @@ func (device *Device) Close() {
 
        device.isUp.Set(false)
 
-       device.signal.stop.Broadcast()
+       close(device.signals.stop)
 
        device.state.stopping.Wait()
        device.FlushPacketQueues()
@@ -413,5 +413,5 @@ func (device *Device) Close() {
 }
 
 func (device *Device) Wait() chan struct{} {
-       return device.signal.stop.Wait()
+       return device.signals.stop
 }
diff --git a/event.go b/event.go
deleted file mode 100644 (file)
index 6235ba4..0000000
--- a/event.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package main
-
-import (
-       "sync/atomic"
-       "time"
-)
-
-type Event struct {
-       guard    int32
-       next     time.Time
-       interval time.Duration
-       C        chan struct{}
-}
-
-func newEvent(interval time.Duration) *Event {
-       return &Event{
-               guard:    0,
-               next:     time.Now(),
-               interval: interval,
-               C:        make(chan struct{}, 1),
-       }
-}
-
-func (e *Event) Clear() {
-       select {
-       case <-e.C:
-       default:
-       }
-}
-
-func (e *Event) Fire() {
-       if e == nil || atomic.SwapInt32(&e.guard, 1) != 0 {
-               return
-       }
-       if now := time.Now(); now.After(e.next) {
-               select {
-               case e.C <- struct{}{}:
-               default:
-               }
-               e.next = now.Add(e.interval)
-       }
-       atomic.StoreInt32(&e.guard, 0)
-}
index c309f234e39bcf00cccec975a035ddca7eafdbf7..4a78d556b849b95a90a1a120cdccfeffdef8d42c 100644 (file)
--- a/index.go
+++ b/index.go
@@ -18,7 +18,7 @@ import (
 type IndexTableEntry struct {
        peer      *Peer
        handshake *Handshake
-       keyPair   *KeyPair
+       keyPair   *Keypair
 }
 
 type IndexTable struct {
index eaf30b2e56f78b516014ba3fcafd627cfda6df15..07a183db6531debc57c8009a2d0a5f8f3d65ca14 100644 (file)
@@ -18,7 +18,7 @@ import (
  * we plan to resolve this issue; whenever Go allows us to do so.
  */
 
-type KeyPair struct {
+type Keypair struct {
        sendNonce    uint64
        send         cipher.AEAD
        receive      cipher.AEAD
@@ -29,20 +29,20 @@ type KeyPair struct {
        remoteIndex  uint32
 }
 
-type KeyPairs struct {
+type Keypairs struct {
        mutex    sync.RWMutex
-       current  *KeyPair
-       previous *KeyPair
-       next     *KeyPair // not yet "confirmed by transport"
+       current  *Keypair
+       previous *Keypair
+       next     *Keypair // not yet "confirmed by transport"
 }
 
-func (kp *KeyPairs) Current() *KeyPair {
+func (kp *Keypairs) Current() *Keypair {
        kp.mutex.RLock()
        defer kp.mutex.RUnlock()
        return kp.current
 }
 
-func (device *Device) DeleteKeyPair(key *KeyPair) {
+func (device *Device) DeleteKeypair(key *Keypair) {
        if key != nil {
                device.indices.Delete(key.localIndex)
        }
diff --git a/main.go b/main.go
index ecfbc502b6bc9b811634c5334add853585decb89..5001bc436e1d918870e7427fef989d1b48ef2208 100644 (file)
--- a/main.go
+++ b/main.go
@@ -30,6 +30,8 @@ func printUsage() {
 }
 
 func warning() {
+       shouldQuit := false
+
        fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
        fmt.Fprintln(os.Stderr, "W                                                     G")
        fmt.Fprintln(os.Stderr, "W   This is alpha software. It will very likely not   G")
@@ -37,6 +39,8 @@ func warning() {
        fmt.Fprintln(os.Stderr, "W   horribly wrong. You have been warned. Proceed     G")
        fmt.Fprintln(os.Stderr, "W   at your own risk.                                 G")
        if runtime.GOOS == "linux" {
+               shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
+
                fmt.Fprintln(os.Stderr, "W                                                     G")
                fmt.Fprintln(os.Stderr, "W   Furthermore, you are running this software on a   G")
                fmt.Fprintln(os.Stderr, "W   Linux kernel, which is probably unnecessary and   G")
@@ -46,9 +50,20 @@ func warning() {
                fmt.Fprintln(os.Stderr, "W   program. For more information on installing the   G")
                fmt.Fprintln(os.Stderr, "W   kernel module, please visit:                      G")
                fmt.Fprintln(os.Stderr, "W           https://www.wireguard.com/install         G")
+               if shouldQuit {
+                       fmt.Fprintln(os.Stderr, "W                                                     G")
+                       fmt.Fprintln(os.Stderr, "W   If you still want to use this program, against    G")
+                       fmt.Fprintln(os.Stderr, "W   the sage advice here, please first export this    G")
+                       fmt.Fprintln(os.Stderr, "W   environment variable:                             G")
+                       fmt.Fprintln(os.Stderr, "W   WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1    G")
+               }
        }
        fmt.Fprintln(os.Stderr, "W                                                     G")
        fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
+
+       if shouldQuit {
+               os.Exit(1)
+       }
 }
 
 func main() {
index 35e95efdb4b61f86ff740647f41fd74b90f12619..3abbe4bec43cc806b0529659fa11c15acf25ebad 100644 (file)
@@ -1,6 +1,6 @@
 /* SPDX-License-Identifier: GPL-2.0
  *
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
  */
 
 package main
@@ -488,7 +488,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 /* Derives a new key-pair from the current handshake state
  *
  */
-func (peer *Peer) NewKeyPair() *KeyPair {
+func (peer *Peer) NewKeypair() *Keypair {
        device := peer.device
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -528,7 +528,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 
        // create AEAD instances
 
-       keyPair := new(KeyPair)
+       keyPair := new(Keypair)
        keyPair.send, _ = chacha20poly1305.New(sendKey[:])
        keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
 
@@ -559,24 +559,27 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        kp := &peer.keyPairs
        kp.mutex.Lock()
 
-       if isInitiator {
-               if kp.previous != nil {
-                       device.DeleteKeyPair(kp.previous)
-                       kp.previous = nil
-               }
+       peer.timersSessionDerived()
+
+       previous := kp.previous
+       next := kp.next
+       current := kp.current
 
-               if kp.next != nil {
-                       kp.previous = kp.next
-                       kp.next = keyPair
+       if isInitiator {
+               if next != nil {
+                       kp.next = nil
+                       kp.previous = next
+                       device.DeleteKeypair(current)
                } else {
-                       kp.previous = kp.current
-                       kp.current = keyPair
-                       peer.event.newKeyPair.Fire()
+                       kp.previous = current
                }
-
+               device.DeleteKeypair(previous)
+               kp.current = keyPair
        } else {
                kp.next = keyPair
+               device.DeleteKeypair(next)
                kp.previous = nil
+               device.DeleteKeypair(previous)
        }
        kp.mutex.Unlock()
 
index 958a4effa9278c1c7fb9d950c4f144ffe8f4eed0..37bfb94710062d2fff63083e666f0d186358ff08 100644 (file)
@@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) {
 
        t.Log("deriving keys")
 
-       key1 := peer1.NewKeyPair()
-       key2 := peer2.NewKeyPair()
+       key1 := peer1.NewKeypair()
+       key2 := peer2.NewKeypair()
 
        if key1 == nil {
                t.Fatal("failed to dervice key-pair for peer 1")
diff --git a/peer.go b/peer.go
index 739c8fba30440832251dbbfdf303d18cde7cf661..242729ebd222b823a628faa171eab8fca99cde7b 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -14,14 +14,13 @@ import (
 )
 
 const (
-       PeerRoutineNumber = 4
-       EventInterval     = 10 * time.Millisecond
+       PeerRoutineNumber = 3
 )
 
 type Peer struct {
        isRunning                   AtomicBool
        mutex                       sync.RWMutex
-       keyPairs                    KeyPairs
+       keyPairs                    Keypairs
        handshake                   Handshake
        device                      *Device
        endpoint                    Endpoint
@@ -34,34 +33,28 @@ type Peer struct {
                lastHandshakeNano int64  // nano seconds since epoch
        }
 
-       time struct {
-               mutex         sync.RWMutex
-               lastSend      time.Time // last send message
-               lastHandshake time.Time // last completed handshake
-               nextKeepalive time.Time
+       timers struct {
+               retransmitHandshake     *Timer
+               sendKeepalive           *Timer
+               newHandshake            *Timer
+               zeroKeyMaterial         *Timer
+               persistentKeepalive     *Timer
+               handshakeAttempts       uint
+               needAnotherKeepalive    bool
+               sentLastMinuteHandshake bool
+               lastSentHandshake       time.Time
        }
 
-       event struct {
-               dataSent                        *Event
-               dataReceived                    *Event
-               anyAuthenticatedPacketReceived  *Event
-               anyAuthenticatedPacketTraversal *Event
-               handshakeCompleted              *Event
-               handshakePushDeadline           *Event
-               handshakeBegin                  *Event
-               ephemeralKeyCreated             *Event
-               newKeyPair                      *Event
-               flushNonceQueue                 *Event
-       }
-
-       timer struct {
-               sendLastMinuteHandshake AtomicBool
+       signals struct {
+               newKeypairArrived chan struct{}
+               flushNonceQueue   chan struct{}
        }
 
        queue struct {
-               nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
-               outbound chan *QueueOutboundElement // sequential ordering of work
-               inbound  chan *QueueInboundElement  // sequential ordering of work
+               nonce                           chan *QueueOutboundElement // nonce / pre-handshake queue
+               outbound                        chan *QueueOutboundElement // sequential ordering of work
+               inbound                         chan *QueueInboundElement  // sequential ordering of work
+               packetInNonceQueueIsAwaitingKey bool
        }
 
        routines struct {
@@ -188,6 +181,8 @@ func (peer *Peer) Start() {
        peer.routines.starting.Wait()
        peer.routines.stopping.Wait()
        peer.routines.stop = make(chan struct{})
+       peer.routines.starting.Add(PeerRoutineNumber)
+       peer.routines.stopping.Add(PeerRoutineNumber)
 
        // prepare queues
 
@@ -195,28 +190,13 @@ func (peer *Peer) Start() {
        peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
        peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
 
-       // events
-
-       peer.event.dataSent = newEvent(EventInterval)
-       peer.event.dataReceived = newEvent(EventInterval)
-       peer.event.anyAuthenticatedPacketReceived = newEvent(EventInterval)
-       peer.event.anyAuthenticatedPacketTraversal = newEvent(EventInterval)
-       peer.event.handshakeCompleted = newEvent(EventInterval)
-       peer.event.handshakePushDeadline = newEvent(EventInterval)
-       peer.event.handshakeBegin = newEvent(EventInterval)
-       peer.event.ephemeralKeyCreated = newEvent(EventInterval)
-       peer.event.newKeyPair = newEvent(EventInterval)
-       peer.event.flushNonceQueue = newEvent(EventInterval)
-
-       peer.isRunning.Set(true)
+       peer.timersInit()
+       peer.signals.newKeypairArrived = make(chan struct{}, 1)
+       peer.signals.flushNonceQueue = make(chan struct{}, 1)
 
        // wait for routines to start
 
-       peer.routines.starting.Add(PeerRoutineNumber)
-       peer.routines.stopping.Add(PeerRoutineNumber)
-
        go peer.RoutineNonce()
-       go peer.RoutineTimerHandler()
        go peer.RoutineSequentialSender()
        go peer.RoutineSequentialReceiver()
 
@@ -238,6 +218,8 @@ func (peer *Peer) Stop() {
        device := peer.device
        device.log.Debug.Println(peer, ": Stopping...")
 
+       peer.timersStop()
+
        // stop & wait for ongoing peer routines
 
        peer.routines.starting.Wait()
@@ -255,9 +237,9 @@ func (peer *Peer) Stop() {
        kp := &peer.keyPairs
        kp.mutex.Lock()
 
-       device.DeleteKeyPair(kp.previous)
-       device.DeleteKeyPair(kp.current)
-       device.DeleteKeyPair(kp.next)
+       device.DeleteKeypair(kp.previous)
+       device.DeleteKeypair(kp.current)
+       device.DeleteKeypair(kp.next)
 
        kp.previous = nil
        kp.current = nil
@@ -271,4 +253,6 @@ func (peer *Peer) Stop() {
        device.indices.Delete(hs.localIndex)
        hs.Clear()
        hs.mutex.Unlock()
+
+       peer.FlushNonceQueue()
 }
index 1cf77b26871f2a896d8fb56f3bfaa22804498a51..0f22a3f252959ef091405a44654e36dc3daa79f9 100644 (file)
@@ -31,7 +31,7 @@ type QueueInboundElement struct {
        buffer   *[MaxMessageSize]byte
        packet   []byte
        counter  uint64
-       keyPair  *KeyPair
+       keyPair  *Keypair
        endpoint Endpoint
 }
 
@@ -99,6 +99,21 @@ func (device *Device) addToHandshakeQueue(
        }
 }
 
+/* Called when a new authenticated message has been received
+ *
+ * NOTE: Not thread safe, but called by sequential receiver!
+ */
+func (peer *Peer) keepKeyFreshReceiving() {
+       if peer.timers.sentLastMinuteHandshake {
+               return
+       }
+       kp := peer.keyPairs.Current()
+       if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
+               peer.timers.sentLastMinuteHandshake = true
+               peer.SendHandshakeInitiation(false)
+       }
+}
+
 /* Receives incoming datagrams for the device
  *
  * Every time the bind is updated a new routine is started for
@@ -245,7 +260,7 @@ func (device *Device) RoutineDecryption() {
 
        for {
                select {
-               case <-device.signal.stop.Wait():
+               case <-device.signals.stop:
                        return
 
                case elem, ok := <-device.queue.decryption:
@@ -317,7 +332,7 @@ func (device *Device) RoutineHandshake() {
        for {
                select {
                case elem, ok = <-device.queue.handshake:
-               case <-device.signal.stop.Wait():
+               case <-device.signals.stop:
                        return
                }
 
@@ -441,8 +456,8 @@ func (device *Device) RoutineHandshake() {
 
                        // update timers
 
-                       peer.event.anyAuthenticatedPacketTraversal.Fire()
-                       peer.event.anyAuthenticatedPacketReceived.Fire()
+                       peer.timersAnyAuthenticatedPacketTraversal()
+                       peer.timersAnyAuthenticatedPacketReceived()
 
                        // update endpoint
 
@@ -460,10 +475,11 @@ func (device *Device) RoutineHandshake() {
                                continue
                        }
 
-                       peer.TimerEphemeralKeyCreated()
-                       peer.NewKeyPair()
+                       if peer.NewKeypair() == nil {
+                               continue
+                       }
 
-                       logDebug.Println(peer, ": Creating handshake response")
+                       logDebug.Println(peer, ": Sending handshake response")
 
                        writer := bytes.NewBuffer(temp[:0])
                        binary.Write(writer, binary.LittleEndian, response)
@@ -472,9 +488,10 @@ func (device *Device) RoutineHandshake() {
 
                        // send response
 
+                       peer.timers.lastSentHandshake = time.Now()
                        err = peer.SendBuffer(packet)
                        if err == nil {
-                               peer.event.anyAuthenticatedPacketTraversal.Fire()
+                               peer.timersAnyAuthenticatedPacketTraversal()
                        } else {
                                logError.Println(peer, ": Failed to send handshake response", err)
                        }
@@ -510,18 +527,23 @@ func (device *Device) RoutineHandshake() {
 
                        logDebug.Println(peer, ": Received handshake response")
 
-                       peer.TimerEphemeralKeyCreated()
-
                        // update timers
 
-                       peer.event.anyAuthenticatedPacketTraversal.Fire()
-                       peer.event.anyAuthenticatedPacketReceived.Fire()
-                       peer.event.handshakeCompleted.Fire()
+                       peer.timersAnyAuthenticatedPacketTraversal()
+                       peer.timersAnyAuthenticatedPacketReceived()
 
                        // derive key-pair
 
-                       peer.NewKeyPair()
-                       peer.SendKeepAlive()
+                       if peer.NewKeypair() == nil {
+                               continue
+                       }
+
+                       peer.timersHandshakeComplete()
+                       peer.SendKeepalive()
+                       select {
+                       case peer.signals.newKeypairArrived <- struct{}{}:
+                       default:
+                       }
                }
        }
 }
@@ -569,38 +591,41 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                continue
                        }
 
-                       peer.event.anyAuthenticatedPacketTraversal.Fire()
-                       peer.event.anyAuthenticatedPacketReceived.Fire()
-                       peer.KeepKeyFreshReceiving()
+                       // update endpoint
+
+                       peer.mutex.Lock()
+                       peer.endpoint = elem.endpoint
+                       peer.mutex.Unlock()
 
                        // check if using new key-pair
 
                        kp := &peer.keyPairs
-                       kp.mutex.Lock()
+                       kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
                        if kp.next == elem.keyPair {
-                               peer.event.handshakeCompleted.Fire()
-                               if kp.previous != nil {
-                                       device.DeleteKeyPair(kp.previous)
-                               }
+                               old := kp.previous
                                kp.previous = kp.current
+                               device.DeleteKeypair(old)
                                kp.current = kp.next
                                kp.next = nil
+                               peer.timersHandshakeComplete()
+                               select {
+                               case peer.signals.newKeypairArrived <- struct{}{}:
+                               default:
+                               }
                        }
                        kp.mutex.Unlock()
 
-                       // update endpoint
-
-                       peer.mutex.Lock()
-                       peer.endpoint = elem.endpoint
-                       peer.mutex.Unlock()
+                       peer.keepKeyFreshReceiving()
+                       peer.timersAnyAuthenticatedPacketTraversal()
+                       peer.timersAnyAuthenticatedPacketReceived()
 
-                       // check for keep-alive
+                       // check for keepalive
 
                        if len(elem.packet) == 0 {
-                               logDebug.Println(peer, ": Received keep-alive")
+                               logDebug.Println(peer, ": Receiving keepalive packet")
                                continue
                        }
-                       peer.event.dataReceived.Fire()
+                       peer.timersDataReceived()
 
                        // verify source and strip padding
 
diff --git a/send.go b/send.go
index ddebb99e50164e3fd2f741bb3437d5f9caf79ad1..1b35e275eae275d89b7137fc810277fa4d60c026 100644 (file)
--- a/send.go
+++ b/send.go
@@ -6,6 +6,7 @@
 package main
 
 import (
+       "bytes"
        "encoding/binary"
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/net/ipv4"
@@ -46,21 +47,10 @@ type QueueOutboundElement struct {
        buffer  *[MaxMessageSize]byte // slice holding the packet data
        packet  []byte                // slice of "buffer" (always!)
        nonce   uint64                // nonce for encryption
-       keyPair *KeyPair              // key-pair for encryption
+       keyPair *Keypair              // key-pair for encryption
        peer    *Peer                 // related peer
 }
 
-func (peer *Peer) flushNonceQueue() {
-       elems := len(peer.queue.nonce)
-       for i := 0; i < elems; i++ {
-               select {
-               case <-peer.queue.nonce:
-               default:
-                       return
-               }
-       }
-}
-
 func (device *Device) NewOutboundElement() *QueueOutboundElement {
        return &QueueOutboundElement{
                dropped: AtomicFalse,
@@ -114,6 +104,73 @@ func addToEncryptionQueue(
        }
 }
 
+/* Queues a keepalive if no packets are queued for peer
+ */
+func (peer *Peer) SendKeepalive() bool {
+       if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey {
+               return false
+       }
+       elem := peer.device.NewOutboundElement()
+       elem.packet = nil
+       select {
+       case peer.queue.nonce <- elem:
+               peer.device.log.Debug.Println(peer, ": Sending keepalive packet")
+               return true
+       default:
+               return false
+       }
+}
+
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
+       if !isRetry {
+               peer.timers.handshakeAttempts = 0
+       }
+
+       if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout {
+               return nil
+       }
+       peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable?
+
+       // create initiation message
+
+       msg, err := peer.device.CreateMessageInitiation(peer)
+       if err != nil {
+               return err
+       }
+
+       peer.device.log.Debug.Println(peer, ": Sending handshake initiation")
+
+       // marshal handshake message
+
+       var buff [MessageInitiationSize]byte
+       writer := bytes.NewBuffer(buff[:0])
+       binary.Write(writer, binary.LittleEndian, msg)
+       packet := writer.Bytes()
+       peer.mac.AddMacs(packet)
+
+       // send to endpoint
+
+       peer.timersAnyAuthenticatedPacketTraversal()
+       peer.timersHandshakeInitiated()
+       return peer.SendBuffer(packet)
+}
+
+/* Called when a new authenticated message has been send
+ *
+ */
+func (peer *Peer) keepKeyFreshSending() {
+       kp := peer.keyPairs.Current()
+       if kp == nil {
+               return
+       }
+       nonce := atomic.LoadUint64(&kp.sendNonce)
+       if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) {
+               peer.SendHandshakeInitiation(false)
+       }
+}
+
 /* Reads packets from the TUN and inserts
  * into nonce queue for peer
  *
@@ -180,13 +237,22 @@ func (device *Device) RoutineReadFromTUN() {
                // insert into nonce/pre-handshake queue
 
                if peer.isRunning.Get() {
-                       peer.event.handshakePushDeadline.Fire()
+                       if peer.queue.packetInNonceQueueIsAwaitingKey {
+                               peer.SendHandshakeInitiation(false)
+                       }
                        addToOutboundQueue(peer.queue.nonce, elem)
                        elem = device.NewOutboundElement()
                }
        }
 }
 
+func (peer *Peer) FlushNonceQueue() {
+       select {
+       case peer.signals.flushNonceQueue <- struct{}{}:
+       default:
+       }
+}
+
 /* Queues packets when there is no handshake.
  * Then assigns nonces to packets sequentially
  * and creates "work" structs for workers
@@ -194,13 +260,14 @@ func (device *Device) RoutineReadFromTUN() {
  * Obs. A single instance per peer
  */
 func (peer *Peer) RoutineNonce() {
-       var keyPair *KeyPair
+       var keyPair *Keypair
 
        device := peer.device
        logDebug := device.log.Debug
 
        defer func() {
                logDebug.Println(peer, ": Routine: nonce worker - stopped")
+               peer.queue.packetInNonceQueueIsAwaitingKey = false
                peer.routines.stopping.Done()
        }()
 
@@ -209,8 +276,7 @@ func (peer *Peer) RoutineNonce() {
 
        for {
        NextPacket:
-
-               peer.event.flushNonceQueue.Clear()
+               peer.queue.packetInNonceQueueIsAwaitingKey = false
 
                select {
                case <-peer.routines.stop:
@@ -225,34 +291,48 @@ func (peer *Peer) RoutineNonce() {
                        // wait for key pair
 
                        for {
-
-                               peer.event.newKeyPair.Clear()
-
                                keyPair = peer.keyPairs.Current()
                                if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
                                        if time.Now().Sub(keyPair.created) < RejectAfterTime {
                                                break
                                        }
                                }
+                               peer.queue.packetInNonceQueueIsAwaitingKey = true
 
-                               peer.event.handshakeBegin.Fire()
+                               select {
+                               case <-peer.signals.newKeypairArrived:
+                               default:
+                               }
+
+                               peer.SendHandshakeInitiation(false)
 
                                logDebug.Println(peer, ": Awaiting key-pair")
 
                                select {
-                               case <-peer.event.newKeyPair.C:
+                               case <-peer.signals.newKeypairArrived:
                                        logDebug.Println(peer, ": Obtained awaited key-pair")
-                               case <-peer.event.flushNonceQueue.C:
-                                       goto NextPacket
+                               case <-peer.signals.flushNonceQueue:
+                                       for {
+                                               select {
+                                               case <-peer.queue.nonce:
+                                               default:
+                                                       goto NextPacket
+                                               }
+                                       }
                                case <-peer.routines.stop:
                                        return
                                }
                        }
+                       peer.queue.packetInNonceQueueIsAwaitingKey = false
 
                        // populate work element
 
                        elem.peer = peer
                        elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+                       // double check in case of race condition added by future code
+                       if elem.nonce >= RejectAfterMessages {
+                               goto NextPacket
+                       }
                        elem.keyPair = keyPair
                        elem.dropped = AtomicFalse
                        elem.mutex.Lock()
@@ -288,7 +368,7 @@ func (device *Device) RoutineEncryption() {
                // fetch next element
 
                select {
-               case <-device.signal.stop.Wait():
+               case <-device.signals.stop:
                        return
 
                case elem, ok := <-device.queue.encryption:
@@ -389,11 +469,11 @@ func (peer *Peer) RoutineSequentialSender() {
 
                        // update timers
 
-                       peer.event.anyAuthenticatedPacketTraversal.Fire()
+                       peer.timersAnyAuthenticatedPacketTraversal()
                        if len(elem.packet) != MessageKeepaliveSize {
-                               peer.event.dataSent.Fire()
+                               peer.timersDataSent()
                        }
-                       peer.KeepKeyFreshSending()
+                       peer.keepKeyFreshSending()
                }
        }
 }
diff --git a/signal.go b/signal.go
deleted file mode 100644 (file)
index 606da52..0000000
--- a/signal.go
+++ /dev/null
@@ -1,71 +0,0 @@
-/* SPDX-License-Identifier: GPL-2.0
- *
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- */
-
-package main
-
-func signalSend(s chan<- struct{}) {
-       select {
-       case s <- struct{}{}:
-       default:
-       }
-}
-
-type Signal struct {
-       enabled AtomicBool
-       C       chan struct{}
-}
-
-func NewSignal() (s Signal) {
-       s.C = make(chan struct{}, 1)
-       s.Enable()
-       return
-}
-
-func (s *Signal) Close() {
-       close(s.C)
-}
-
-func (s *Signal) Disable() {
-       s.enabled.Set(false)
-       s.Clear()
-}
-
-func (s *Signal) Enable() {
-       s.enabled.Set(true)
-}
-
-/* Unblock exactly one listener
- */
-func (s *Signal) Send() {
-       if s.enabled.Get() {
-               select {
-               case s.C <- struct{}{}:
-               default:
-               }
-       }
-}
-
-/* Clear the signal if already fired
- */
-func (s Signal) Clear() {
-       select {
-       case <-s.C:
-       default:
-       }
-}
-
-/* Unblocks all listeners (forever)
- */
-func (s Signal) Broadcast() {
-       if s.enabled.Get() {
-               close(s.C)
-       }
-}
-
-/* Wait for the signal
- */
-func (s Signal) Wait() chan struct{} {
-       return s.C
-}
index 38c9b460731ddf5a763895d5561753866c988cab..5c72efd4428d6c0db01f7f8b7afd9b4596a82b0f 100644 (file)
--- a/timers.go
+++ b/timers.go
 /* SPDX-License-Identifier: GPL-2.0
  *
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ *
+ * This is based heavily on timers.c from the kernel implementation.
  */
 
 package main
 
 import (
-       "bytes"
-       "encoding/binary"
        "math/rand"
        "sync/atomic"
        "time"
 )
 
-/* NOTE:
- * Notion of validity
+/* This Timer structure and related functions should roughly copy the interface of
+ * the Linux kernel's struct timer_list.
  */
 
-/* Called when a new authenticated message has been send
- *
- */
-func (peer *Peer) KeepKeyFreshSending() {
-       kp := peer.keyPairs.Current()
-       if kp == nil {
-               return
-       }
-       nonce := atomic.LoadUint64(&kp.sendNonce)
-       if nonce > RekeyAfterMessages {
-               peer.event.handshakeBegin.Fire()
-       }
-       if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
-               peer.event.handshakeBegin.Fire()
-       }
+type Timer struct {
+       timer     *time.Timer
+       isPending bool
 }
 
-/* Called when a new authenticated message has been received
- *
- * NOTE: Not thread safe, but called by sequential receiver!
- */
-func (peer *Peer) KeepKeyFreshReceiving() {
-       if peer.timer.sendLastMinuteHandshake.Get() {
-               return
-       }
-       kp := peer.keyPairs.Current()
-       if kp == nil {
-               return
-       }
-       if !kp.isInitiator {
-               return
-       }
-       nonce := atomic.LoadUint64(&kp.sendNonce)
-       send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
-       if send {
-               // do a last minute attempt at initiating a new handshake
-               peer.timer.sendLastMinuteHandshake.Set(true)
-               peer.event.handshakeBegin.Fire()
-       }
+func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
+       timer := &Timer{}
+       timer.timer = time.AfterFunc(time.Hour, func() {
+               timer.isPending = false
+               expirationFunction(peer)
+       })
+       timer.timer.Stop()
+       return timer
 }
 
-/* Queues a keep-alive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepAlive() bool {
-       if len(peer.queue.nonce) != 0 {
-               return false
-       }
-       elem := peer.device.NewOutboundElement()
-       elem.packet = nil
-       select {
-       case peer.queue.nonce <- elem:
-               return true
-       default:
-               return false
-       }
+func (timer *Timer) Mod(d time.Duration) {
+       timer.isPending = true
+       timer.timer.Reset(d)
 }
 
-/* Called after successfully completing a handshake.
- * i.e. after:
- *
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-// peer.device.log.Info.Println(peer, ": New handshake completed")
-
-/* Event:
- * An ephemeral key is generated
- *
- * i.e. after:
- *
- * CreateMessageInitiation
- * CreateMessageResponse
- *
- * Action:
- * Schedule the deletion of all key material
- * upon failure to complete a handshake
- */
-func (peer *Peer) TimerEphemeralKeyCreated() {
-       peer.event.ephemeralKeyCreated.Fire()
-       // peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
+func (timer *Timer) Del() {
+       timer.isPending = false
+       timer.timer.Stop()
 }
 
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
-       // create initiation message
-
-       msg, err := peer.device.CreateMessageInitiation(peer)
-       if err != nil {
-               return err
-       }
+func (peer *Peer) timersActive() bool {
+       return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
+}
 
-       // marshal handshake message
+func expiredRetransmitHandshake(peer *Peer) {
+       if peer.timers.handshakeAttempts > MaxTimerHandshakes {
+               peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
 
-       var buff [MessageInitiationSize]byte
-       writer := bytes.NewBuffer(buff[:0])
-       binary.Write(writer, binary.LittleEndian, msg)
-       packet := writer.Bytes()
-       peer.mac.AddMacs(packet)
+               if peer.timersActive() {
+                       peer.timers.sendKeepalive.Del()
+               }
 
-       // send to endpoint
+               /* We drop all packets without a keypair and don't try again,
+                * if we try unsuccessfully for too long to make a handshake.
+                */
+               peer.FlushNonceQueue()
 
-       peer.event.anyAuthenticatedPacketTraversal.Fire()
+               /* We set a timer for destroying any residue that might be left
+                * of a partial exchange.
+                */
+               if peer.timersActive() && !peer.timers.zeroKeyMaterial.isPending {
+                       peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
+               }
+       } else {
+               peer.timers.handshakeAttempts++
+               peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts+1)
+
+               /* We clear the endpoint address src address, in case this is the cause of trouble. */
+               peer.mutex.Lock()
+               if peer.endpoint != nil {
+                       peer.endpoint.ClearSrc()
+               }
+               peer.mutex.Unlock()
 
-       return peer.SendBuffer(packet)
+               peer.SendHandshakeInitiation(true)
+       }
 }
 
-func newTimer() *time.Timer {
-       timer := time.NewTimer(time.Hour)
-       timer.Stop()
-       return timer
+func expiredSendKeepalive(peer *Peer) {
+       peer.SendKeepalive()
+       if peer.timers.needAnotherKeepalive {
+               peer.timers.needAnotherKeepalive = false
+               if peer.timersActive() {
+                       peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
+               }
+       }
 }
 
-func (peer *Peer) RoutineTimerHandler() {
-
-       device := peer.device
-
-       logInfo := device.log.Info
-       logDebug := device.log.Debug
-
-       defer func() {
-               logDebug.Println(peer, ": Routine: timer handler - stopped")
-               peer.routines.stopping.Done()
-       }()
-
-       logDebug.Println(peer, ": Routine: timer handler - started")
-
-       // reset all timers
-
-       enableHandshake := true
-       pendingHandshakeNew := false
-       pendingKeepalivePassive := false
-       needAnotherKeepalive := false
-
-       timerKeepalivePassive := newTimer()
-       timerHandshakeDeadline := newTimer()
-       timerHandshakeTimeout := newTimer()
-       timerHandshakeNew := newTimer()
-       timerZeroAllKeys := newTimer()
-       timerKeepalivePersistent := newTimer()
-
-       interval := peer.persistentKeepaliveInterval
-       if interval > 0 {
-               duration := time.Duration(interval) * time.Second
-               timerKeepalivePersistent.Reset(duration)
+func expiredNewHandshake(peer *Peer) {
+       peer.device.log.Debug.Printf("%s: Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
+       /* We clear the endpoint address src address, in case this is the cause of trouble. */
+       peer.mutex.Lock()
+       if peer.endpoint != nil {
+               peer.endpoint.ClearSrc()
        }
+       peer.mutex.Unlock()
+       peer.SendHandshakeInitiation(false)
 
-       // signal synchronised setup complete
-
-       peer.routines.starting.Done()
-
-       // handle timer events
-
-       for {
-               select {
-
-               /* stopping */
-
-               case <-peer.routines.stop:
-                       return
-
-               /* events */
-
-               case <-peer.event.dataSent.C:
-                       timerKeepalivePassive.Stop()
-                       if !pendingHandshakeNew {
-                               timerHandshakeNew.Reset(NewHandshakeTime)
-                       }
-
-               case <-peer.event.dataReceived.C:
-                       if pendingKeepalivePassive {
-                               needAnotherKeepalive = true
-                       } else {
-                               timerKeepalivePassive.Reset(KeepaliveTimeout)
-                       }
-
-               case <-peer.event.anyAuthenticatedPacketTraversal.C:
-                       interval := peer.persistentKeepaliveInterval
-                       if interval > 0 {
-                               duration := time.Duration(interval) * time.Second
-                               timerKeepalivePersistent.Reset(duration)
-                       }
-
-               case <-peer.event.handshakeBegin.C:
-
-                       if !enableHandshake {
-                               continue
-                       }
-
-                       logDebug.Println(peer, ": Event, Handshake Begin")
-
-                       err := peer.sendNewHandshake()
-
-                       // set timeout
-
-                       jitter := time.Millisecond * time.Duration(rand.Int31n(334))
-                       timerKeepalivePassive.Stop()
-                       timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
-
-                       if err != nil {
-                               logInfo.Println(peer, ": Failed to send handshake initiation", err)
-                       } else {
-                               logDebug.Println(peer, ": Send handshake initiation (initial)")
-                       }
-
-                       timerHandshakeDeadline.Reset(RekeyAttemptTime)
-
-                       // disable further handshakes
-
-                       peer.event.handshakeBegin.Clear()
-                       enableHandshake = false
-
-               case <-peer.event.handshakeCompleted.C:
-
-                       logInfo.Println(peer, ": Handshake completed")
-
-                       atomic.StoreInt64(
-                               &peer.stats.lastHandshakeNano,
-                               time.Now().UnixNano(),
-                       )
-
-                       timerHandshakeTimeout.Stop()
-                       timerHandshakeDeadline.Stop()
-                       peer.timer.sendLastMinuteHandshake.Set(false)
-
-                       // allow further handshakes
-
-                       peer.event.handshakeBegin.Clear()
-                       enableHandshake = true
-
-               /* timers */
-
-               case <-timerKeepalivePersistent.C:
-
-                       interval := peer.persistentKeepaliveInterval
-                       if interval > 0 {
-                               logDebug.Println(peer, ": Send keep-alive (persistent)")
-                               timerKeepalivePassive.Stop()
-                               peer.SendKeepAlive()
-                       }
-
-               case <-timerKeepalivePassive.C:
-
-                       logDebug.Println(peer, ": Send keep-alive (passive)")
-
-                       peer.SendKeepAlive()
-
-                       if needAnotherKeepalive {
-                               timerKeepalivePassive.Reset(KeepaliveTimeout)
-                               needAnotherKeepalive = false
-                       }
-
-               case <-timerZeroAllKeys.C:
-
-                       logDebug.Println(peer, ": Clear all key-material (timer event)")
-
-                       hs := &peer.handshake
-                       hs.mutex.Lock()
-
-                       kp := &peer.keyPairs
-                       kp.mutex.Lock()
-
-                       // remove key-pairs
-
-                       if kp.previous != nil {
-                               device.DeleteKeyPair(kp.previous)
-                               kp.previous = nil
-                       }
-                       if kp.current != nil {
-                               device.DeleteKeyPair(kp.current)
-                               kp.current = nil
-                       }
-                       if kp.next != nil {
-                               device.DeleteKeyPair(kp.next)
-                               kp.next = nil
-                       }
-                       kp.mutex.Unlock()
-
-                       // zero out handshake
-
-                       device.indices.Delete(hs.localIndex)
-                       hs.Clear()
-                       hs.mutex.Unlock()
-
-               case <-timerHandshakeTimeout.C:
-
-                       // allow new handshake to be send
+}
 
-                       enableHandshake = true
+func expiredZeroKeyMaterial(peer *Peer) {
+       peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
 
-                       // clear source (in case this is causing problems)
+       hs := &peer.handshake
+       hs.mutex.Lock()
 
-                       peer.mutex.Lock()
-                       if peer.endpoint != nil {
-                               peer.endpoint.ClearSrc()
-                       }
-                       peer.mutex.Unlock()
+       kp := &peer.keyPairs
+       kp.mutex.Lock()
 
-                       // send new handshake
+       if kp.previous != nil {
+               peer.device.DeleteKeypair(kp.previous)
+               kp.previous = nil
+       }
+       if kp.current != nil {
+               peer.device.DeleteKeypair(kp.current)
+               kp.current = nil
+       }
+       if kp.next != nil {
+               peer.device.DeleteKeypair(kp.next)
+               kp.next = nil
+       }
+       kp.mutex.Unlock()
 
-                       err := peer.sendNewHandshake()
+       peer.device.indices.Delete(hs.localIndex)
+       hs.Clear()
+       hs.mutex.Unlock()
+}
 
-                       // set timeout
+func expiredPersistentKeepalive(peer *Peer) {
+       if peer.persistentKeepaliveInterval > 0 {
+               if peer.timersActive() {
+                       peer.timers.sendKeepalive.Del()
+               }
+               peer.SendKeepalive()
+       }
+}
 
-                       jitter := time.Millisecond * time.Duration(rand.Int31n(334))
-                       timerKeepalivePassive.Stop()
-                       timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
+/* Should be called after an authenticated data packet is sent. */
+func (peer *Peer) timersDataSent() {
+       if peer.timersActive() {
+               peer.timers.sendKeepalive.Del()
+       }
 
-                       if err != nil {
-                               logInfo.Println(peer, ": Failed to send handshake initiation", err)
-                       } else {
-                               logDebug.Println(peer, ": Send handshake initiation (subsequent)")
-                       }
+       if peer.timersActive() && !peer.timers.newHandshake.isPending {
+               peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
+       }
+}
 
-                       // disable further handshakes
+/* Should be called after an authenticated data packet is received. */
+func (peer *Peer) timersDataReceived() {
+       if peer.timersActive() {
+               if !peer.timers.sendKeepalive.isPending {
+                       peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
+               } else {
+                       peer.timers.needAnotherKeepalive = true
+               }
+       }
+}
 
-                       peer.event.handshakeBegin.Clear()
-                       enableHandshake = false
+/* Should be called after any type of authenticated packet is received -- keepalive or data. */
+func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
+       if peer.timersActive() {
+               peer.timers.newHandshake.Del()
+       }
+}
 
-               case <-timerHandshakeDeadline.C:
+/* Should be called after a handshake initiation message is sent. */
+func (peer *Peer) timersHandshakeInitiated() {
+       if peer.timersActive() {
+               peer.timers.sendKeepalive.Del()
+               peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+       }
+}
 
-                       // clear all queued packets and stop keep-alive
+/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
+func (peer *Peer) timersHandshakeComplete() {
+       if peer.timersActive() {
+               peer.timers.retransmitHandshake.Del()
+       }
+       peer.timers.handshakeAttempts = 0
+       peer.timers.sentLastMinuteHandshake = false
+       atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+}
 
-                       logInfo.Println(peer, ": Handshake negotiation timed-out")
+/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
+func (peer *Peer) timersSessionDerived() {
+       if peer.timersActive() {
+               peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
+       }
+}
 
-                       peer.flushNonceQueue()
-                       peer.event.flushNonceQueue.Fire()
+/* Should be called before a packet with authentication -- data, keepalive, either handshake -- is sent, or after one is received. */
+func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
+       if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
+               peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+       }
+}
 
-                       // renable further handshakes
+func (peer *Peer) timersInit() {
+       peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
+       peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
+       peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
+       peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
+       peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
+       peer.timers.handshakeAttempts = 0
+       peer.timers.sentLastMinuteHandshake = false
+       peer.timers.needAnotherKeepalive = false
+       peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+}
 
-                       peer.event.handshakeBegin.Clear()
-                       enableHandshake = true
-               }
-       }
+func (peer *Peer) timersStop() {
+       peer.timers.retransmitHandshake.Del()
+       peer.timers.sendKeepalive.Del()
+       peer.timers.newHandshake.Del()
+       peer.timers.zeroKeyMaterial.Del()
+       peer.timers.persistentKeepalive.Del()
 }
diff --git a/uapi.go b/uapi.go
index 54d9bae80ac123550b709dd6e5e33cd7ef5bbdcd..4b2038b5cc0f599741445becbfa0d8e7df86e616 100644 (file)
--- a/uapi.go
+++ b/uapi.go
@@ -256,8 +256,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logDebug.Println("UAPI: Created new peer:", peer)
                                }
 
-                               peer.event.handshakePushDeadline.Fire()
-
                        case "remove":
 
                                // remove currently selected peer from device
@@ -288,8 +286,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
 
-                               peer.event.handshakePushDeadline.Fire()
-
                        case "endpoint":
 
                                // set endpoint destination
@@ -304,7 +300,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                                return err
                                        }
                                        peer.endpoint = endpoint
-                                       peer.event.handshakePushDeadline.Fire()
                                        return nil
                                }()
 
@@ -315,7 +310,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                        case "persistent_keepalive_interval":
 
-                               // update keep-alive interval
+                               // update persistent keepalive interval
 
                                logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer)
 
@@ -328,7 +323,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                old := peer.persistentKeepaliveInterval
                                peer.persistentKeepaliveInterval = uint16(secs)
 
-                               // send immediate keep-alive
+                               // send immediate keepalive if we're turning it on and before it wasn't on
 
                                if old == 0 && secs != 0 {
                                        if err != nil {
@@ -336,7 +331,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                                return &IPCError{Code: ipcErrorIO}
                                        }
                                        if device.isUp.Get() && !dummy {
-                                               peer.SendKeepAlive()
+                                               peer.SendKeepalive()
                                        }
                                }