]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
More refactoring
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 21:14:43 +0000 (23:14 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 21:20:13 +0000 (23:20 +0200)
conn.go
device.go
noise-protocol.go
noise_test.go
peer.go
receive.go
send.go
timers.go
uapi.go

diff --git a/conn.go b/conn.go
index 082bbca94f2e08c63411fabaf75c2f8882cc5846..4b347ec1b042dee039491e6f58d4617b493fb1c1 100644 (file)
--- a/conn.go
+++ b/conn.go
@@ -74,9 +74,6 @@ func (device *Device) BindSetMark(mark uint32) error {
        device.net.mutex.Lock()
        defer device.net.mutex.Unlock()
 
-       device.peers.mutex.Lock()
-       defer device.peers.mutex.Unlock()
-
        // check if modified
 
        if device.net.fwmark == mark {
@@ -92,6 +89,18 @@ func (device *Device) BindSetMark(mark uint32) error {
                }
        }
 
+       // clear cached source addresses
+
+       device.peers.mutex.RLock()
+       for _, peer := range device.peers.keyMap {
+               peer.mutex.Lock()
+               defer peer.mutex.Unlock()
+               if peer.endpoint != nil {
+                       peer.endpoint.ClearSrc()
+               }
+       }
+       device.peers.mutex.RUnlock()
+
        return nil
 }
 
@@ -100,9 +109,6 @@ func (device *Device) BindUpdate() error {
        device.net.mutex.Lock()
        defer device.net.mutex.Unlock()
 
-       device.peers.mutex.Lock()
-       defer device.peers.mutex.Unlock()
-
        // close existing sockets
 
        if err := unsafeCloseBind(device); err != nil {
@@ -135,6 +141,7 @@ func (device *Device) BindUpdate() error {
 
                // clear cached source addresses
 
+               device.peers.mutex.RLock()
                for _, peer := range device.peers.keyMap {
                        peer.mutex.Lock()
                        defer peer.mutex.Unlock()
@@ -142,6 +149,7 @@ func (device *Device) BindUpdate() error {
                                peer.endpoint.ClearSrc()
                        }
                }
+               device.peers.mutex.RUnlock()
 
                // start receiving routines
 
index 34af419048d8a5a425a65a170ffd4783066ee7e6..cc12ac9d8d6f42150ded35c8d20afa472bac63c9 100644 (file)
--- a/device.go
+++ b/device.go
@@ -38,17 +38,12 @@ type Device struct {
                fwmark uint32 // mark value (0 = disabled)
        }
 
-       noise struct {
+       staticIdentity struct {
                mutex      sync.RWMutex
                privateKey NoisePrivateKey
                publicKey  NoisePublicKey
        }
 
-       routing struct {
-               mutex sync.RWMutex
-               table AllowedIPs
-       }
-
        peers struct {
                mutex  sync.RWMutex
                keyMap map[NoisePublicKey]*Peer
@@ -56,8 +51,9 @@ type Device struct {
 
        // unprotected / "self-synchronising resources"
 
-       indexTable IndexTable
-       mac        CookieChecker
+       allowedips    AllowedIPs
+       indexTable    IndexTable
+       cookieChecker CookieChecker
 
        rate struct {
                underLoadUntil atomic.Value
@@ -87,15 +83,13 @@ type Device struct {
 /* Converts the peer into a "zombie", which remains in the peer map,
  * but processes no packets and does not exists in the routing table.
  *
- * Must hold:
- *  device.peers.mutex : exclusive lock
- *  device.routing     : exclusive lock
+ * Must hold device.peers.mutex.
  */
 func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
 
        // stop routing and processing of packets
 
-       device.routing.table.RemoveByPeer(peer)
+       device.allowedips.RemoveByPeer(peer)
        peer.Stop()
 
        // remove from peer map
@@ -131,19 +125,19 @@ func deviceUpdateState(device *Device) {
                        device.isUp.Set(false)
                        break
                }
-               device.peers.mutex.Lock()
+               device.peers.mutex.RLock()
                for _, peer := range device.peers.keyMap {
                        peer.Start()
                }
-               device.peers.mutex.Unlock()
+               device.peers.mutex.RUnlock()
 
        case false:
                device.BindClose()
-               device.peers.mutex.Lock()
+               device.peers.mutex.RLock()
                for _, peer := range device.peers.keyMap {
                        peer.Stop()
                }
-               device.peers.mutex.Unlock()
+               device.peers.mutex.RUnlock()
        }
 
        // update state variables
@@ -199,11 +193,8 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 
        // lock required resources
 
-       device.noise.mutex.Lock()
-       defer device.noise.mutex.Unlock()
-
-       device.routing.mutex.Lock()
-       defer device.routing.mutex.Unlock()
+       device.staticIdentity.mutex.Lock()
+       defer device.staticIdentity.mutex.Unlock()
 
        device.peers.mutex.Lock()
        defer device.peers.mutex.Unlock()
@@ -224,13 +215,13 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 
        // update key material
 
-       device.noise.privateKey = sk
-       device.noise.publicKey = publicKey
-       device.mac.Init(publicKey)
+       device.staticIdentity.privateKey = sk
+       device.staticIdentity.publicKey = publicKey
+       device.cookieChecker.Init(publicKey)
 
        // do static-static DH pre-computations
 
-       rmKey := device.noise.privateKey.IsZero()
+       rmKey := device.staticIdentity.privateKey.IsZero()
 
        for key, peer := range device.peers.keyMap {
 
@@ -239,7 +230,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
                if rmKey {
                        hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
                } else {
-                       hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic)
+                       hs.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(hs.remoteStatic)
                }
 
                if isZero(hs.precomputedStaticStatic[:]) {
@@ -281,10 +272,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
        device.rate.limiter.Init()
        device.rate.underLoadUntil.Store(time.Time{})
 
-       // initialize noise & crypt-key routine
+       // initialize staticIdentity & crypt-key routine
 
        device.indexTable.Init()
-       device.routing.table.Reset()
+       device.allowedips.Reset()
 
        // setup buffer pool
 
@@ -333,12 +324,6 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
 }
 
 func (device *Device) RemovePeer(key NoisePublicKey) {
-       device.noise.mutex.Lock()
-       defer device.noise.mutex.Unlock()
-
-       device.routing.mutex.Lock()
-       defer device.routing.mutex.Unlock()
-
        device.peers.mutex.Lock()
        defer device.peers.mutex.Unlock()
 
@@ -351,12 +336,6 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
 }
 
 func (device *Device) RemoveAllPeers() {
-       device.noise.mutex.Lock()
-       defer device.noise.mutex.Unlock()
-
-       device.routing.mutex.Lock()
-       defer device.routing.mutex.Unlock()
-
        device.peers.mutex.Lock()
        defer device.peers.mutex.Unlock()
 
index f72dcc480cf2c839aff531ae7d54918cbf0c367b..ffc2b50e9f04159a35587a9703b0bbd974fe5dcb 100644 (file)
@@ -107,6 +107,7 @@ type Handshake struct {
        precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
        lastTimestamp             tai64n.Timestamp
        lastInitiationConsumption time.Time
+       lastSentHandshake         time.Time
 }
 
 var (
@@ -153,8 +154,8 @@ func init() {
 
 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
 
-       device.noise.mutex.RLock()
-       defer device.noise.mutex.RUnlock()
+       device.staticIdentity.mutex.RLock()
+       defer device.staticIdentity.mutex.RUnlock()
 
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -206,7 +207,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
                        ss[:],
                )
                aead, _ := chacha20poly1305.New(key[:])
-               aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])
+               aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
        }()
        handshake.mixHash(msg.Static[:])
 
@@ -240,10 +241,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
                return nil
        }
 
-       device.noise.mutex.RLock()
-       defer device.noise.mutex.RUnlock()
+       device.staticIdentity.mutex.RLock()
+       defer device.staticIdentity.mutex.RUnlock()
 
-       mixHash(&hash, &InitialHash, device.noise.publicKey[:])
+       mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
        mixHash(&hash, &hash, msg.Ephemeral[:])
        mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
 
@@ -253,7 +254,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        var peerPK NoisePublicKey
        func() {
                var key [chacha20poly1305.KeySize]byte
-               ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
+               ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
                KDF2(&chainKey, &key, chainKey[:], ss[:])
                aead, _ := chacha20poly1305.New(key[:])
                _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@@ -422,8 +423,8 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 
                // lock private key for reading
 
-               device.noise.mutex.RLock()
-               defer device.noise.mutex.RUnlock()
+               device.staticIdentity.mutex.RLock()
+               defer device.staticIdentity.mutex.RUnlock()
 
                // finish 3-way DH
 
@@ -437,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                }()
 
                func() {
-                       ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
+                       ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
                        mixKey(&chainKey, &chainKey, ss[:])
                        setZero(ss[:])
                }()
@@ -490,7 +491,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 /* Derives a new keypair from the current handshake state
  *
  */
-func (peer *Peer) DeriveNewKeypair() error {
+func (peer *Peer) BeginSymmetricSession() error {
        device := peer.device
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -552,50 +553,48 @@ func (peer *Peer) DeriveNewKeypair() error {
 
        // rotate key pairs
 
-       kp := &peer.keypairs
-       kp.mutex.Lock()
+       keypairs := &peer.keypairs
+       keypairs.mutex.Lock()
+       defer keypairs.mutex.Unlock()
 
-       peer.timersSessionDerived()
-
-       previous := kp.previous
-       next := kp.next
-       current := kp.current
+       previous := keypairs.previous
+       next := keypairs.next
+       current := keypairs.current
 
        if isInitiator {
                if next != nil {
-                       kp.next = nil
-                       kp.previous = next
+                       keypairs.next = nil
+                       keypairs.previous = next
                        device.DeleteKeypair(current)
                } else {
-                       kp.previous = current
+                       keypairs.previous = current
                }
                device.DeleteKeypair(previous)
-               kp.current = keypair
+               keypairs.current = keypair
        } else {
-               kp.next = keypair
+               keypairs.next = keypair
                device.DeleteKeypair(next)
-               kp.previous = nil
+               keypairs.previous = nil
                device.DeleteKeypair(previous)
        }
-       kp.mutex.Unlock()
 
        return nil
 }
 
 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
-       kp := &peer.keypairs
-       if kp.next != receivedKeypair {
+       keypairs := &peer.keypairs
+       if keypairs.next != receivedKeypair {
                return false
        }
-       kp.mutex.Lock()
-       defer kp.mutex.Unlock()
-       if kp.next != receivedKeypair {
+       keypairs.mutex.Lock()
+       defer keypairs.mutex.Unlock()
+       if keypairs.next != receivedKeypair {
                return false
        }
-       old := kp.previous
-       kp.previous = kp.current
+       old := keypairs.previous
+       keypairs.previous = keypairs.current
        peer.device.DeleteKeypair(old)
-       kp.current = kp.next
-       kp.next = nil
+       keypairs.current = keypairs.next
+       keypairs.next = nil
        return true
 }
index ce32097fd6896be71ab21a4b98ae41b648dd6737..8e1bd89f71e2b98e31201bf9759a2572ee6ecb48 100644 (file)
@@ -36,8 +36,8 @@ func TestNoiseHandshake(t *testing.T) {
        defer dev1.Close()
        defer dev2.Close()
 
-       peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey())
-       peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey())
+       peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
+       peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
 
        assertEqual(
                t,
@@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) {
 
        t.Log("deriving keys")
 
-       key1 := peer1.DeriveNewKeypair()
-       key2 := peer2.DeriveNewKeypair()
+       key1 := peer1.BeginSymmetricSession()
+       key2 := peer2.BeginSymmetricSession()
 
        if key1 == nil {
                t.Fatal("failed to dervice keypair for peer 1")
diff --git a/peer.go b/peer.go
index d574c7114df8da7f64ebdc48e7868c19990098e8..1151341bb79ccc74f6c0da1472b59a491c224a1d 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -19,7 +19,7 @@ const (
 
 type Peer struct {
        isRunning                   AtomicBool
-       mutex                       sync.RWMutex
+       mutex                       sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
        keypairs                    Keypairs
        handshake                   Handshake
        device                      *Device
@@ -42,7 +42,6 @@ type Peer struct {
                handshakeAttempts       uint
                needAnotherKeepalive    bool
                sentLastMinuteHandshake bool
-               lastSentHandshake       time.Time
        }
 
        signals struct {
@@ -64,7 +63,7 @@ type Peer struct {
                stop     chan struct{}  // size 0, stop all go routines in peer
        }
 
-       mac CookieGenerator
+       cookieGenerator CookieGenerator
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -75,11 +74,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        // lock resources
 
-       device.state.mutex.Lock()
-       defer device.state.mutex.Unlock()
-
-       device.noise.mutex.RLock()
-       defer device.noise.mutex.RUnlock()
+       device.staticIdentity.mutex.RLock()
+       defer device.staticIdentity.mutex.RUnlock()
 
        device.peers.mutex.Lock()
        defer device.peers.mutex.Unlock()
@@ -96,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        peer.mutex.Lock()
        defer peer.mutex.Unlock()
 
-       peer.mac.Init(pk)
+       peer.cookieGenerator.Init(pk)
        peer.device = device
        peer.isRunning.Set(false)
 
@@ -113,7 +109,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        handshake := &peer.handshake
        handshake.mutex.Lock()
        handshake.remoteStatic = pk
-       handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk)
+       handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
        handshake.mutex.Unlock()
 
        // reset endpoint
@@ -191,6 +187,7 @@ func (peer *Peer) Start() {
        peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
 
        peer.timersInit()
+       peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
        peer.signals.newKeypairArrived = make(chan struct{}, 1)
        peer.signals.flushNonceQueue = make(chan struct{}, 1)
 
@@ -204,6 +201,32 @@ func (peer *Peer) Start() {
        peer.isRunning.Set(true)
 }
 
+func (peer *Peer) ZeroAndFlushAll() {
+       device := peer.device
+
+       // clear key pairs
+
+       keypairs := &peer.keypairs
+       keypairs.mutex.Lock()
+       device.DeleteKeypair(keypairs.previous)
+       device.DeleteKeypair(keypairs.current)
+       device.DeleteKeypair(keypairs.next)
+       keypairs.previous = nil
+       keypairs.current = nil
+       keypairs.next = nil
+       keypairs.mutex.Unlock()
+
+       // clear handshake state
+
+       handshake := &peer.handshake
+       handshake.mutex.Lock()
+       device.indexTable.Delete(handshake.localIndex)
+       handshake.Clear()
+       handshake.mutex.Unlock()
+
+       peer.FlushNonceQueue()
+}
+
 func (peer *Peer) Stop() {
 
        // prevent simultaneous start/stop operations
@@ -215,8 +238,7 @@ func (peer *Peer) Stop() {
                return
        }
 
-       device := peer.device
-       device.log.Debug.Println(peer, ": Stopping...")
+       peer.device.log.Debug.Println(peer, ": Stopping...")
 
        peer.timersStop()
 
@@ -232,27 +254,5 @@ func (peer *Peer) Stop() {
        close(peer.queue.outbound)
        close(peer.queue.inbound)
 
-       // clear key pairs
-
-       kp := &peer.keypairs
-       kp.mutex.Lock()
-
-       device.DeleteKeypair(kp.previous)
-       device.DeleteKeypair(kp.current)
-       device.DeleteKeypair(kp.next)
-
-       kp.previous = nil
-       kp.current = nil
-       kp.next = nil
-       kp.mutex.Unlock()
-
-       // clear handshake state
-
-       hs := &peer.handshake
-       hs.mutex.Lock()
-       device.indexTable.Delete(hs.localIndex)
-       hs.Clear()
-       hs.mutex.Unlock()
-
-       peer.FlushNonceQueue()
+       peer.ZeroAndFlushAll()
 }
index 64253e6eeadeb7da0e63013f401556665752a094..77062fa6a2c8f1dde1ee6b546ae005901102ba4d 100644 (file)
@@ -107,8 +107,8 @@ func (peer *Peer) keepKeyFreshReceiving() {
        if peer.timers.sentLastMinuteHandshake {
                return
        }
-       kp := peer.keypairs.Current()
-       if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
+       keypair := peer.keypairs.Current()
+       if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
                peer.timers.sentLastMinuteHandshake = true
                peer.SendHandshakeInitiation(false)
        }
@@ -325,7 +325,6 @@ func (device *Device) RoutineHandshake() {
 
        logDebug.Println("Routine: handshake worker - started")
 
-       var temp [MessageHandshakeSize]byte
        var elem QueueHandshakeElement
        var ok bool
 
@@ -367,52 +366,28 @@ func (device *Device) RoutineHandshake() {
                        // consume reply
 
                        if peer := entry.peer; peer.isRunning.Get() {
-                               peer.mac.ConsumeReply(&reply)
+                               peer.cookieGenerator.ConsumeReply(&reply)
                        }
 
                        continue
 
                case MessageInitiationType, MessageResponseType:
 
-                       // check mac fields and ratelimit
+                       // check mac fields and maybe ratelimit
 
-                       if !device.mac.CheckMAC1(elem.packet) {
+                       if !device.cookieChecker.CheckMAC1(elem.packet) {
                                logDebug.Println("Received packet with invalid mac1")
                                continue
                        }
 
                        // endpoints destination address is the source of the datagram
 
-                       srcBytes := elem.endpoint.DstToBytes()
-
                        if device.IsUnderLoad() {
 
                                // verify MAC2 field
 
-                               if !device.mac.CheckMAC2(elem.packet, srcBytes) {
-
-                                       // construct cookie reply
-
-                                       logDebug.Println(
-                                               "Sending cookie reply to:",
-                                               elem.endpoint.DstToString(),
-                                       )
-
-                                       sender := binary.LittleEndian.Uint32(elem.packet[4:8])
-                                       reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
-                                       if err != nil {
-                                               logError.Println("Failed to create cookie reply:", err)
-                                               continue
-                                       }
-
-                                       // marshal and send reply
-
-                                       writer := bytes.NewBuffer(temp[:0])
-                                       binary.Write(writer, binary.LittleEndian, reply)
-                                       device.net.bind.Send(writer.Bytes(), elem.endpoint)
-                                       if err != nil {
-                                               logDebug.Println("Failed to send cookie reply:", err)
-                                       }
+                               if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
+                                       device.SendHandshakeCookie(&elem)
                                        continue
                                }
 
@@ -467,34 +442,7 @@ func (device *Device) RoutineHandshake() {
 
                        logDebug.Println(peer, ": Received handshake initiation")
 
-                       // create response
-
-                       response, err := device.CreateMessageResponse(peer)
-                       if err != nil {
-                               logError.Println("Failed to create response message:", err)
-                               continue
-                       }
-
-                       if peer.DeriveNewKeypair() != nil {
-                               continue
-                       }
-
-                       logDebug.Println(peer, ": Sending handshake response")
-
-                       writer := bytes.NewBuffer(temp[:0])
-                       binary.Write(writer, binary.LittleEndian, response)
-                       packet := writer.Bytes()
-                       peer.mac.AddMacs(packet)
-
-                       // send response
-
-                       peer.timers.lastSentHandshake = time.Now()
-                       err = peer.SendBuffer(packet)
-                       if err == nil {
-                               peer.timersAnyAuthenticatedPacketTraversal()
-                       } else {
-                               logError.Println(peer, ": Failed to send handshake response", err)
-                       }
+                       peer.SendHandshakeResponse()
 
                case MessageResponseType:
 
@@ -534,10 +482,14 @@ func (device *Device) RoutineHandshake() {
 
                        // derive keypair
 
-                       if peer.DeriveNewKeypair() != nil {
+                       err = peer.BeginSymmetricSession()
+
+                       if err != nil {
+                               logError.Println(peer, ": Failed to derive keypair:", err)
                                continue
                        }
 
+                       peer.timersSessionDerived()
                        peer.timersHandshakeComplete()
                        peer.SendKeepalive()
                        select {
@@ -640,7 +592,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                // verify IPv4 source
 
                                src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
-                               if device.routing.table.LookupIPv4(src) != peer {
+                               if device.allowedips.LookupIPv4(src) != peer {
                                        logInfo.Println(
                                                "IPv4 packet with disallowed source address from",
                                                peer,
@@ -668,7 +620,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                // verify IPv6 source
 
                                src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
-                               if device.routing.table.LookupIPv6(src) != peer {
+                               if device.allowedips.LookupIPv6(src) != peer {
                                        logInfo.Println(
                                                peer,
                                                "sent packet with disallowed IPv6 source",
diff --git a/send.go b/send.go
index a8ec28cf4d36d8c3d7a7c5ad1c7c9c0f6a046d94..a670c4d1faada3724757444361423a8e848384bf 100644 (file)
--- a/send.go
+++ b/send.go
@@ -121,52 +121,114 @@ func (peer *Peer) SendKeepalive() bool {
        }
 }
 
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
        if !isRetry {
                peer.timers.handshakeAttempts = 0
        }
 
-       if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout {
+       peer.handshake.mutex.RLock()
+       if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
+               peer.handshake.mutex.RUnlock()
+               return nil
+       }
+       peer.handshake.mutex.RUnlock()
+
+       peer.handshake.mutex.Lock()
+       if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
+               peer.handshake.mutex.Unlock()
                return nil
        }
-       peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable?
+       peer.handshake.lastSentHandshake = time.Now()
+       peer.handshake.mutex.Unlock()
 
-       // create initiation message
+       peer.device.log.Debug.Println(peer, ": Sending handshake initiation")
 
        msg, err := peer.device.CreateMessageInitiation(peer)
        if err != nil {
+               peer.device.log.Error.Println(peer, ": Failed to create initiation message:", err)
                return err
        }
 
-       peer.device.log.Debug.Println(peer, ": Sending handshake initiation")
-
-       // marshal handshake message
-
        var buff [MessageInitiationSize]byte
        writer := bytes.NewBuffer(buff[:0])
        binary.Write(writer, binary.LittleEndian, msg)
        packet := writer.Bytes()
-       peer.mac.AddMacs(packet)
-
-       // send to endpoint
+       peer.cookieGenerator.AddMacs(packet)
 
        peer.timersAnyAuthenticatedPacketTraversal()
+
+       err = peer.SendBuffer(packet)
+       if err != nil {
+               peer.device.log.Error.Println(peer, ": Failed to send handshake initiation", err)
+       }
        peer.timersHandshakeInitiated()
-       return peer.SendBuffer(packet)
+
+       return err
+}
+
+func (peer *Peer) SendHandshakeResponse() error {
+       peer.handshake.mutex.Lock()
+       peer.handshake.lastSentHandshake = time.Now()
+       peer.handshake.mutex.Unlock()
+
+       peer.device.log.Debug.Println(peer, ": Sending handshake response")
+
+       response, err := peer.device.CreateMessageResponse(peer)
+       if err != nil {
+               peer.device.log.Error.Println(peer, ": Failed to create response message:", err)
+               return err
+       }
+
+       var buff [MessageResponseSize]byte
+       writer := bytes.NewBuffer(buff[:0])
+       binary.Write(writer, binary.LittleEndian, response)
+       packet := writer.Bytes()
+       peer.cookieGenerator.AddMacs(packet)
+
+       err = peer.BeginSymmetricSession()
+       if err != nil {
+               peer.device.log.Error.Println(peer, ": Failed to derive keypair:", err)
+               return err
+       }
+
+       peer.timersSessionDerived()
+       peer.timersAnyAuthenticatedPacketTraversal()
+
+       err = peer.SendBuffer(packet)
+       if err != nil {
+               peer.device.log.Error.Println(peer, ": Failed to send handshake response", err)
+       }
+       return err
+}
+
+func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
+
+       device.log.Debug.Println("Sending cookie reply to:", initiatingElem.endpoint.DstToString())
+
+       sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
+       reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
+       if err != nil {
+               device.log.Error.Println("Failed to create cookie reply:", err)
+               return err
+       }
+
+       var buff [MessageCookieReplySize]byte
+       writer := bytes.NewBuffer(buff[:0])
+       binary.Write(writer, binary.LittleEndian, reply)
+       device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+       if err != nil {
+               device.log.Error.Println("Failed to send cookie reply:", err)
+       }
+       return err
 }
 
-/* Called when a new authenticated message has been send
- *
- */
 func (peer *Peer) keepKeyFreshSending() {
-       kp := peer.keypairs.Current()
-       if kp == nil {
+       keypair := peer.keypairs.Current()
+       if keypair == nil {
                return
        }
-       nonce := atomic.LoadUint64(&kp.sendNonce)
-       if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) {
+       nonce := atomic.LoadUint64(&keypair.sendNonce)
+       if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) {
                peer.SendHandshakeInitiation(false)
        }
 }
@@ -217,14 +279,14 @@ func (device *Device) RoutineReadFromTUN() {
                                continue
                        }
                        dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
-                       peer = device.routing.table.LookupIPv4(dst)
+                       peer = device.allowedips.LookupIPv4(dst)
 
                case ipv6.Version:
                        if len(elem.packet) < ipv6.HeaderLen {
                                continue
                        }
                        dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
-                       peer = device.routing.table.LookupIPv6(dst)
+                       peer = device.allowedips.LookupIPv6(dst)
 
                default:
                        logDebug.Println("Received packet with unknown IP version")
index 9e633eec6d76ca4725baa4f3221a423a1f9893c3..e13237694034512f95ad0307c5eb61314b548d21 100644 (file)
--- a/timers.go
+++ b/timers.go
@@ -104,30 +104,7 @@ func expiredNewHandshake(peer *Peer) {
 
 func expiredZeroKeyMaterial(peer *Peer) {
        peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
-
-       hs := &peer.handshake
-       hs.mutex.Lock()
-
-       kp := &peer.keypairs
-       kp.mutex.Lock()
-
-       if kp.previous != nil {
-               peer.device.DeleteKeypair(kp.previous)
-               kp.previous = nil
-       }
-       if kp.current != nil {
-               peer.device.DeleteKeypair(kp.current)
-               kp.current = nil
-       }
-       if kp.next != nil {
-               peer.device.DeleteKeypair(kp.next)
-               kp.next = nil
-       }
-       kp.mutex.Unlock()
-
-       peer.device.indexTable.Delete(hs.localIndex)
-       hs.Clear()
-       hs.mutex.Unlock()
+       peer.ZeroAndFlushAll()
 }
 
 func expiredPersistentKeepalive(peer *Peer) {
@@ -209,7 +186,6 @@ func (peer *Peer) timersInit() {
        peer.timers.handshakeAttempts = 0
        peer.timers.sentLastMinuteHandshake = false
        peer.timers.needAnotherKeepalive = false
-       peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
 }
 
 func (peer *Peer) timersStop() {
diff --git a/uapi.go b/uapi.go
index 90c400aa923081d21fa264032ba94b17f0cd7238..53a598e849e06457bfd4ebd0f7ed0ff0e201db3f 100644 (file)
--- a/uapi.go
+++ b/uapi.go
@@ -46,19 +46,16 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                device.net.mutex.RLock()
                defer device.net.mutex.RUnlock()
 
-               device.noise.mutex.RLock()
-               defer device.noise.mutex.RUnlock()
+               device.staticIdentity.mutex.RLock()
+               defer device.staticIdentity.mutex.RUnlock()
 
-               device.routing.mutex.RLock()
-               defer device.routing.mutex.RUnlock()
-
-               device.peers.mutex.Lock()
-               defer device.peers.mutex.Unlock()
+               device.peers.mutex.RLock()
+               defer device.peers.mutex.RUnlock()
 
                // serialize device related values
 
-               if !device.noise.privateKey.IsZero() {
-                       send("private_key=" + device.noise.privateKey.ToHex())
+               if !device.staticIdentity.privateKey.IsZero() {
+                       send("private_key=" + device.staticIdentity.privateKey.ToHex())
                }
 
                if device.net.port != 0 {
@@ -91,7 +88,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes))
                        send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
 
-                       for _, ip := range device.routing.table.EntriesForPeer(peer) {
+                       for _, ip := range device.allowedips.EntriesForPeer(peer) {
                                send("allowed_ip=" + ip.String())
                        }
 
@@ -234,13 +231,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                // ignore peer with public key of device
 
-                               device.noise.mutex.RLock()
-                               equals := device.noise.publicKey.Equals(publicKey)
-                               device.noise.mutex.RUnlock()
+                               device.staticIdentity.mutex.RLock()
+                               dummy = device.staticIdentity.publicKey.Equals(publicKey)
+                               device.staticIdentity.mutex.RUnlock()
 
-                               if equals {
+                               if dummy {
                                        peer = &Peer{}
-                                       dummy = true
                                }
 
                                // find peer referenced
@@ -348,9 +344,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        continue
                                }
 
-                               device.routing.mutex.Lock()
-                               device.routing.table.RemoveByPeer(peer)
-                               device.routing.mutex.Unlock()
+                               device.allowedips.RemoveByPeer(peer)
 
                        case "allowed_ip":
 
@@ -367,9 +361,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }
 
                                ones, _ := network.Mask.Size()
-                               device.routing.mutex.Lock()
-                               device.routing.table.Insert(network.IP, uint(ones), peer)
-                               device.routing.mutex.Unlock()
+                               device.allowedips.Insert(network.IP, uint(ones), peer)
 
                        default:
                                logError.Println("Invalid UAPI key (peer configuration):", key)