// open new sockets
- if device.tun.isUp.Get() {
+ if device.isUp.Get() {
device.log.Debug.Println("UDP bind updating")
)
type Device struct {
- closed AtomicBool // device is closed? (acting as guard)
+ isUp AtomicBool // device is up (TUN interface up)?
+ isClosed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32
tun struct {
device TUNDevice
- isUp AtomicBool
mtu int32
}
pool struct {
mac CookieChecker
}
+func (device *Device) Up() {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
+ device.isUp.Set(true)
+ updateBind(device)
+ for _, peer := range device.peers {
+ peer.Start()
+ }
+}
+
+func (device *Device) Down() {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
+ device.isUp.Set(false)
+ closeBind(device)
+ for _, peer := range device.peers {
+ peer.Stop()
+ }
+}
+
/* Warning:
* The caller must hold the device mutex (write lock)
*/
return
}
peer.mutex.Lock()
+ peer.Stop()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
- peer.Close()
}
func (device *Device) IsUnderLoad() bool {
device.publicKey = publicKey
device.mac.Init(publicKey)
- // do DH precomputations
+ // do DH pre-computations
rmKey := device.privateKey.IsZero()
device.mutex.Lock()
defer device.mutex.Unlock()
+ device.isUp.Set(false)
+ device.isClosed.Set(false)
+
device.log = logger
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
- device.tun.isUp.Set(false)
device.indices.Init()
device.ratelimiter.Init()
func (device *Device) RemoveAllPeers() {
device.mutex.Lock()
defer device.mutex.Unlock()
-
- for key, peer := range device.peers {
- peer.mutex.Lock()
- delete(device.peers, key)
- peer.Close()
- peer.mutex.Unlock()
+ for key := range device.peers {
+ removePeerUnsafe(device, key)
}
}
func (device *Device) Close() {
- if device.closed.Swap(true) {
+ if device.isClosed.Swap(true) {
return
}
device.log.Info.Println("Closing device")
flushNonceQueue Signal // size 1, empty queued packets
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
- stop Signal // size 0, stop all goroutines
+ stop Signal // size 0, stop all goroutines in peer
}
timer struct {
// state related to WireGuard timers
keepalivePersistent Timer // set for persistent keepalives
keepalivePassive Timer // set upon recieving messages
- newHandshake Timer // begin a new handshake (stale)
zeroAllKeys Timer // zero all key material
+ handshakeNew Timer // begin a new handshake (stale)
handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout
peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer()
- peer.timer.newHandshake = NewTimer()
peer.timer.zeroAllKeys = NewTimer()
+ peer.timer.handshakeNew = NewTimer()
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
// prepare signaling & routines
- peer.signal.stop = NewSignal()
peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal()
- go peer.RoutineNonce()
- go peer.RoutineTimerHandler()
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
-
return peer, nil
}
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
+
peer.mutex.RLock()
defer peer.mutex.RUnlock()
+
if peer.endpoint == nil {
return errors.New("No known endpoint for peer")
}
+
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
-/* Returns a short string identification for logging
+/* Returns a short string identifier for logging
*/
func (peer *Peer) String() string {
if peer.endpoint == nil {
)
}
-func (peer *Peer) Close() {
+/* Starts all routines for a given peer
+ *
+ * Requires that the caller holds the exclusive peer lock!
+ */
+func unsafePeerStart(peer *Peer) {
+ peer.signal.stop.Broadcast()
+ peer.signal.stop = NewSignal()
+
+ var wait sync.WaitGroup
+
+ wait.Add(1)
+
+ go peer.RoutineNonce()
+ go peer.RoutineTimerHandler(&wait)
+ go peer.RoutineSequentialSender()
+ go peer.RoutineSequentialReceiver()
+
+ wait.Wait()
+}
+
+func (peer *Peer) Start() {
+ peer.mutex.Lock()
+ unsafePeerStart(peer)
+ peer.mutex.Unlock()
+}
+
+func (peer *Peer) Stop() {
peer.signal.stop.Broadcast()
}
t.Start(dur)
}
-func (t *Timer) Push(dur time.Duration) {
- if t.pending.Get() {
- t.Reset(dur)
- }
-}
-
func (t *Timer) Wait() <-chan time.Time {
return t.timer.C
}
"bytes"\r
"encoding/binary"\r
"math/rand"\r
+ "sync"\r
"sync/atomic"\r
"time"\r
)\r
\r
+/* NOTE:\r
+ * Notion of validity\r
+ *\r
+ *\r
+ */\r
+\r
/* Called when a new authenticated message has been send\r
*\r
*/\r
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving\r
if send {\r
// do a last minute attempt at initiating a new handshake\r
- peer.signal.handshakeBegin.Send()\r
peer.timer.sendLastMinuteHandshake = true\r
+ peer.signal.handshakeBegin.Send()\r
}\r
}\r
\r
/* Queues a keep-alive if no packets are queued for peer\r
*/\r
func (peer *Peer) SendKeepAlive() bool {\r
+ if len(peer.queue.nonce) != 0 {\r
+ return false\r
+ }\r
elem := peer.device.NewOutboundElement()\r
elem.packet = nil\r
- if len(peer.queue.nonce) == 0 {\r
- select {\r
- case peer.queue.nonce <- elem:\r
- return true\r
- default:\r
- return false\r
- }\r
+ select {\r
+ case peer.queue.nonce <- elem:\r
+ return true\r
+ default:\r
+ return false\r
}\r
- return true\r
}\r
\r
/* Event:\r
*/\r
func (peer *Peer) TimerDataSent() {\r
peer.timer.keepalivePassive.Stop()\r
- if peer.timer.newHandshake.Pending() {\r
- peer.timer.newHandshake.Reset(NewHandshakeTime)\r
- }\r
+ peer.timer.handshakeNew.Start(NewHandshakeTime)\r
}\r
\r
/* Event:\r
* Any (authenticated) packet received\r
*/\r
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {\r
- peer.timer.newHandshake.Stop()\r
+ peer.timer.handshakeNew.Stop()\r
}\r
\r
/* Event:\r
* - First transport message under the "next" key\r
*/\r
func (peer *Peer) TimerHandshakeComplete() {\r
- atomic.StoreInt64(\r
- &peer.stats.lastHandshakeNano,\r
- time.Now().UnixNano(),\r
- )\r
peer.signal.handshakeCompleted.Send()\r
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())\r
}\r
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)\r
}\r
\r
-func (peer *Peer) RoutineTimerHandler() {\r
+/* Sends a new handshake initiation message to the peer (endpoint)\r
+ */\r
+func (peer *Peer) sendNewHandshake() error {\r
+\r
+ // temporarily disable the handshake complete signal\r
+\r
+ peer.signal.handshakeCompleted.Disable()\r
+\r
+ // create initiation message\r
+\r
+ msg, err := peer.device.CreateMessageInitiation(peer)\r
+ if err != nil {\r
+ return err\r
+ }\r
+\r
+ // marshal handshake message\r
+\r
+ var buff [MessageInitiationSize]byte\r
+ writer := bytes.NewBuffer(buff[:0])\r
+ binary.Write(writer, binary.LittleEndian, msg)\r
+ packet := writer.Bytes()\r
+ peer.mac.AddMacs(packet)\r
+\r
+ // send to endpoint\r
+\r
+ peer.TimerAnyAuthenticatedPacketTraversal()\r
+\r
+ err = peer.SendBuffer(packet)\r
+ if err == nil {\r
+ peer.signal.handshakeCompleted.Enable()\r
+ }\r
+\r
+ // set timeout\r
+\r
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)\r
+\r
+ peer.timer.keepalivePassive.Stop()\r
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)\r
+\r
+ return err\r
+}\r
+\r
+func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {\r
device := peer.device\r
\r
logInfo := device.log.Info\r
logDebug := device.log.Debug\r
logDebug.Println("Routine, timer handler, started for peer", peer.String())\r
\r
+ // reset all timers\r
+\r
+ peer.timer.keepalivePassive.Stop()\r
+ peer.timer.handshakeDeadline.Stop()\r
+ peer.timer.handshakeTimeout.Stop()\r
+ peer.timer.handshakeNew.Stop()\r
+ peer.timer.zeroAllKeys.Stop()\r
+\r
+ interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)\r
+ if interval > 0 {\r
+ duration := time.Duration(interval) * time.Second\r
+ peer.timer.keepalivePersistent.Reset(duration)\r
+ }\r
+\r
+ // signal that timers are reset\r
+\r
+ ready.Done()\r
+\r
+ // handle timer events\r
+\r
for {\r
select {\r
\r
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)\r
if interval > 0 {\r
logDebug.Println("Sending keep-alive to", peer.String())\r
+ peer.timer.keepalivePassive.Stop()\r
peer.SendKeepAlive()\r
}\r
\r
peer.SendKeepAlive()\r
\r
if peer.timer.needAnotherKeepalive {\r
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)\r
peer.timer.needAnotherKeepalive = false\r
+ peer.timer.keepalivePassive.Reset(KeepaliveTimeout)\r
}\r
\r
// clear key material timer\r
\r
// handshake timers\r
\r
- case <-peer.timer.newHandshake.Wait():\r
+ case <-peer.timer.handshakeNew.Wait():\r
logInfo.Println("Retrying handshake with", peer.String())\r
peer.signal.handshakeBegin.Send()\r
\r
logInfo.Println(\r
"Handshake completed for:", peer.String())\r
\r
+ atomic.StoreInt64(\r
+ &peer.stats.lastHandshakeNano,\r
+ time.Now().UnixNano(),\r
+ )\r
+\r
peer.timer.handshakeTimeout.Stop()\r
peer.timer.handshakeDeadline.Stop()\r
peer.signal.handshakeBegin.Enable()\r
- }\r
- }\r
-}\r
-\r
-/* Sends a new handshake initiation message to the peer (endpoint)\r
- */\r
-func (peer *Peer) sendNewHandshake() error {\r
-\r
- // temporarily disable the handshake complete signal\r
-\r
- peer.signal.handshakeCompleted.Disable()\r
-\r
- // create initiation message\r
\r
- msg, err := peer.device.CreateMessageInitiation(peer)\r
- if err != nil {\r
- return err\r
- }\r
-\r
- // marshal handshake message\r
-\r
- var buff [MessageInitiationSize]byte\r
- writer := bytes.NewBuffer(buff[:0])\r
- binary.Write(writer, binary.LittleEndian, msg)\r
- packet := writer.Bytes()\r
- peer.mac.AddMacs(packet)\r
-\r
- // send to endpoint\r
-\r
- err = peer.SendBuffer(packet)\r
- if err == nil {\r
- peer.TimerAnyAuthenticatedPacketTraversal()\r
- peer.signal.handshakeCompleted.Enable()\r
+ peer.timer.sendLastMinuteHandshake = false\r
+ }\r
}\r
-\r
- // set timeout\r
-\r
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)\r
- peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)\r
-\r
- return err\r
}\r
}
if event&TUNEventUp != 0 {
- if !device.tun.isUp.Get() {
- // begin listening for incomming datagrams
- logInfo.Println("Interface set up")
- device.tun.isUp.Set(true)
- updateBind(device)
- }
+ logInfo.Println("Interface set up")
+ device.Up()
}
if event&TUNEventDown != 0 {
- if device.tun.isUp.Get() {
- // stop listening for incomming datagrams
- logInfo.Println("Interface set down")
- device.tun.isUp.Set(false)
- closeBind(device)
- }
+ logInfo.Println("Interface set down")
+ device.Up()
}
}
}
logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO}
}
- if device.tun.isUp.Get() && !dummy {
+ if device.isUp.Get() && !dummy {
peer.SendKeepAlive()
}
}