]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Rework index hashtable
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 16:23:40 +0000 (18:23 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 13 May 2018 16:26:09 +0000 (18:26 +0200)
device.go
indextable.go [moved from index.go with 65% similarity]
keypair.go
noise-protocol.go
peer.go
receive.go
send.go
timers.go

index e127b5b376da50eb6d8af6b93ea30c5ec0a0c94c..3db3609b08fddf16290a34d55af9aa757888a810 100644 (file)
--- a/device.go
+++ b/device.go
@@ -56,8 +56,8 @@ type Device struct {
 
        // unprotected / "self-synchronising resources"
 
-       indices IndexTable
-       mac     CookieChecker
+       indexTable IndexTable
+       mac        CookieChecker
 
        rate struct {
                underLoadUntil atomic.Value
@@ -283,7 +283,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
 
        // initialize noise & crypt-key routine
 
-       device.indices.Init()
+       device.indexTable.Init()
        device.routing.table.Reset()
 
        // setup buffer pool
similarity index 65%
rename from index.go
rename to indextable.go
index 4a78d556b849b95a90a1a120cdccfeffdef8d42c..2d947cdbb0a0c0a55c25ded70998445cfd2fa7ca 100644 (file)
--- a/index.go
@@ -7,18 +7,14 @@ package main
 
 import (
        "crypto/rand"
-       "encoding/binary"
        "sync"
+       "unsafe"
 )
 
-/* Index=0 is reserved for unset indecies
- *
- */
-
 type IndexTableEntry struct {
        peer      *Peer
        handshake *Handshake
-       keyPair   *Keypair
+       keypair   *Keypair
 }
 
 type IndexTable struct {
@@ -27,34 +23,38 @@ type IndexTable struct {
 }
 
 func randUint32() (uint32, error) {
-       var buff [4]byte
-       _, err := rand.Read(buff[:])
-       value := binary.LittleEndian.Uint32(buff[:])
-       return value, err
+       var integer [4]byte
+       _, err := rand.Read(integer[:])
+       return *(*uint32)(unsafe.Pointer(&integer[0])), err
 }
 
 func (table *IndexTable) Init() {
        table.mutex.Lock()
+       defer table.mutex.Unlock()
        table.table = make(map[uint32]IndexTableEntry)
-       table.mutex.Unlock()
 }
 
 func (table *IndexTable) Delete(index uint32) {
-       if index == 0 {
-               return
-       }
        table.mutex.Lock()
+       defer table.mutex.Unlock()
        delete(table.table, index)
-       table.mutex.Unlock()
 }
 
-func (table *IndexTable) Insert(key uint32, value IndexTableEntry) {
+func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
        table.mutex.Lock()
-       table.table[key] = value
-       table.mutex.Unlock()
+       defer table.mutex.Unlock()
+       entry, ok := table.table[index]
+       if !ok {
+               return
+       }
+       table.table[index] = IndexTableEntry{
+               peer:      entry.peer,
+               keypair:   keypair,
+               handshake: nil,
+       }
 }
 
-func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
+func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
        for {
                // generate random index
 
@@ -62,9 +62,6 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
                if err != nil {
                        return index, err
                }
-               if index == 0 {
-                       continue
-               }
 
                // check if index used
 
@@ -75,7 +72,7 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
                        continue
                }
 
-               // map index to handshake
+               // check again while locked
 
                table.mutex.Lock()
                _, found := table.table[index]
@@ -85,8 +82,8 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
                }
                table.table[index] = IndexTableEntry{
                        peer:      peer,
-                       handshake: &peer.handshake,
-                       keyPair:   nil,
+                       handshake: handshake,
+                       keypair:   nil,
                }
                table.mutex.Unlock()
                return index, nil
index 07a183db6531debc57c8009a2d0a5f8f3d65ca14..6f6f7c01bea0b2e37f4e8c84debbce94a4414e59 100644 (file)
@@ -44,6 +44,6 @@ func (kp *Keypairs) Current() *Keypair {
 
 func (device *Device) DeleteKeypair(key *Keypair) {
        if key != nil {
-               device.indices.Delete(key.localIndex)
+               device.indexTable.Delete(key.localIndex)
        }
 }
index 3abbe4bec43cc806b0529659fa11c15acf25ebad..82d553e107b74584b7e0e46ac592a6e169855fe5 100644 (file)
@@ -161,7 +161,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        defer handshake.mutex.Unlock()
 
        if isZero(handshake.precomputedStaticStatic[:]) {
-               return nil, errors.New("Static shared secret is zero")
+               return nil, errors.New("static shared secret is zero")
        }
 
        // create ephemeral key
@@ -176,8 +176,8 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 
        // assign index
 
-       device.indices.Delete(handshake.localIndex)
-       handshake.localIndex, err = device.indices.NewIndex(peer)
+       device.indexTable.Delete(handshake.localIndex)
+       handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
 
        if err != nil {
                return nil, err
@@ -328,14 +328,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        defer handshake.mutex.Unlock()
 
        if handshake.state != HandshakeInitiationConsumed {
-               return nil, errors.New("handshake initation must be consumed first")
+               return nil, errors.New("handshake initiation must be consumed first")
        }
 
        // assign index
 
        var err error
-       device.indices.Delete(handshake.localIndex)
-       handshake.localIndex, err = device.indices.NewIndex(peer)
+       device.indexTable.Delete(handshake.localIndex)
+       handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
        if err != nil {
                return nil, err
        }
@@ -393,9 +393,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                return nil
        }
 
-       // lookup handshake by reciever
+       // lookup handshake by receiver
 
-       lookup := device.indices.Lookup(msg.Receiver)
+       lookup := device.indexTable.Lookup(msg.Receiver)
        handshake := lookup.handshake
        if handshake == nil {
                return nil
@@ -528,35 +528,28 @@ func (peer *Peer) NewKeypair() *Keypair {
 
        // create AEAD instances
 
-       keyPair := new(Keypair)
-       keyPair.send, _ = chacha20poly1305.New(sendKey[:])
-       keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
+       keypair := new(Keypair)
+       keypair.send, _ = chacha20poly1305.New(sendKey[:])
+       keypair.receive, _ = chacha20poly1305.New(recvKey[:])
 
        setZero(sendKey[:])
        setZero(recvKey[:])
 
-       keyPair.created = time.Now()
-       keyPair.sendNonce = 0
-       keyPair.replayFilter.Init()
-       keyPair.isInitiator = isInitiator
-       keyPair.localIndex = peer.handshake.localIndex
-       keyPair.remoteIndex = peer.handshake.remoteIndex
+       keypair.created = time.Now()
+       keypair.sendNonce = 0
+       keypair.replayFilter.Init()
+       keypair.isInitiator = isInitiator
+       keypair.localIndex = peer.handshake.localIndex
+       keypair.remoteIndex = peer.handshake.remoteIndex
 
        // remap index
 
-       device.indices.Insert(
-               handshake.localIndex,
-               IndexTableEntry{
-                       peer:      peer,
-                       keyPair:   keyPair,
-                       handshake: nil,
-               },
-       )
+       device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
        handshake.localIndex = 0
 
        // rotate key pairs
 
-       kp := &peer.keyPairs
+       kp := &peer.keypairs
        kp.mutex.Lock()
 
        peer.timersSessionDerived()
@@ -574,14 +567,14 @@ func (peer *Peer) NewKeypair() *Keypair {
                        kp.previous = current
                }
                device.DeleteKeypair(previous)
-               kp.current = keyPair
+               kp.current = keypair
        } else {
-               kp.next = keyPair
+               kp.next = keypair
                device.DeleteKeypair(next)
                kp.previous = nil
                device.DeleteKeypair(previous)
        }
        kp.mutex.Unlock()
 
-       return keyPair
+       return keypair
 }
diff --git a/peer.go b/peer.go
index 242729ebd222b823a628faa171eab8fca99cde7b..f49f806590adbd9f0cac6c6439b77cb1d8c07d92 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -20,7 +20,7 @@ const (
 type Peer struct {
        isRunning                   AtomicBool
        mutex                       sync.RWMutex
-       keyPairs                    Keypairs
+       keypairs                    Keypairs
        handshake                   Handshake
        device                      *Device
        endpoint                    Endpoint
@@ -234,7 +234,7 @@ func (peer *Peer) Stop() {
 
        // clear key pairs
 
-       kp := &peer.keyPairs
+       kp := &peer.keypairs
        kp.mutex.Lock()
 
        device.DeleteKeypair(kp.previous)
@@ -250,7 +250,7 @@ func (peer *Peer) Stop() {
 
        hs := &peer.handshake
        hs.mutex.Lock()
-       device.indices.Delete(hs.localIndex)
+       device.indexTable.Delete(hs.localIndex)
        hs.Clear()
        hs.mutex.Unlock()
 
index 0f22a3f252959ef091405a44654e36dc3daa79f9..60a2510abe22a6a2f0105c96ad86884b30645736 100644 (file)
@@ -31,7 +31,7 @@ type QueueInboundElement struct {
        buffer   *[MaxMessageSize]byte
        packet   []byte
        counter  uint64
-       keyPair  *Keypair
+       keypair  *Keypair
        endpoint Endpoint
 }
 
@@ -107,7 +107,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
        if peer.timers.sentLastMinuteHandshake {
                return
        }
-       kp := peer.keyPairs.Current()
+       kp := peer.keypairs.Current()
        if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
                peer.timers.sentLastMinuteHandshake = true
                peer.SendHandshakeInitiation(false)
@@ -183,15 +183,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                        receiver := binary.LittleEndian.Uint32(
                                packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
                        )
-                       value := device.indices.Lookup(receiver)
-                       keyPair := value.keyPair
-                       if keyPair == nil {
+                       value := device.indexTable.Lookup(receiver)
+                       keypair := value.keypair
+                       if keypair == nil {
                                continue
                        }
 
                        // check key-pair expiry
 
-                       if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+                       if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
                                continue
                        }
 
@@ -201,7 +201,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                        elem := &QueueInboundElement{
                                packet:   packet,
                                buffer:   buffer,
-                               keyPair:  keyPair,
+                               keypair:  keypair,
                                dropped:  AtomicFalse,
                                endpoint: endpoint,
                        }
@@ -296,7 +296,7 @@ func (device *Device) RoutineDecryption() {
 
                        var err error
                        elem.counter = binary.LittleEndian.Uint64(counter)
-                       elem.packet, err = elem.keyPair.receive.Open(
+                       elem.packet, err = elem.keypair.receive.Open(
                                content[:0],
                                nonce[:],
                                content,
@@ -358,7 +358,7 @@ func (device *Device) RoutineHandshake() {
 
                        // lookup peer from index
 
-                       entry := device.indices.Lookup(reply.Receiver)
+                       entry := device.indexTable.Lookup(reply.Receiver)
 
                        if entry.peer == nil {
                                continue
@@ -587,7 +587,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                        // check for replay
 
-                       if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+                       if !elem.keypair.replayFilter.ValidateCounter(elem.counter) {
                                continue
                        }
 
@@ -599,9 +599,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                        // check if using new key-pair
 
-                       kp := &peer.keyPairs
+                       kp := &peer.keypairs
                        kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
-                       if kp.next == elem.keyPair {
+                       if kp.next == elem.keypair {
                                old := kp.previous
                                kp.previous = kp.current
                                device.DeleteKeypair(old)
diff --git a/send.go b/send.go
index 1b35e275eae275d89b7137fc810277fa4d60c026..35e0d0008f581ed802c8f999841fdd44eee76bef 100644 (file)
--- a/send.go
+++ b/send.go
@@ -47,7 +47,7 @@ type QueueOutboundElement struct {
        buffer  *[MaxMessageSize]byte // slice holding the packet data
        packet  []byte                // slice of "buffer" (always!)
        nonce   uint64                // nonce for encryption
-       keyPair *Keypair              // key-pair for encryption
+       keypair *Keypair              // key-pair for encryption
        peer    *Peer                 // related peer
 }
 
@@ -161,7 +161,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
  *
  */
 func (peer *Peer) keepKeyFreshSending() {
-       kp := peer.keyPairs.Current()
+       kp := peer.keypairs.Current()
        if kp == nil {
                return
        }
@@ -260,7 +260,7 @@ func (peer *Peer) FlushNonceQueue() {
  * Obs. A single instance per peer
  */
 func (peer *Peer) RoutineNonce() {
-       var keyPair *Keypair
+       var keypair *Keypair
 
        device := peer.device
        logDebug := device.log.Debug
@@ -291,9 +291,9 @@ func (peer *Peer) RoutineNonce() {
                        // wait for key pair
 
                        for {
-                               keyPair = peer.keyPairs.Current()
-                               if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
-                                       if time.Now().Sub(keyPair.created) < RejectAfterTime {
+                               keypair = peer.keypairs.Current()
+                               if keypair != nil && keypair.sendNonce < RejectAfterMessages {
+                                       if time.Now().Sub(keypair.created) < RejectAfterTime {
                                                break
                                        }
                                }
@@ -328,12 +328,12 @@ func (peer *Peer) RoutineNonce() {
                        // populate work element
 
                        elem.peer = peer
-                       elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+                       elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
                        // double check in case of race condition added by future code
                        if elem.nonce >= RejectAfterMessages {
                                goto NextPacket
                        }
-                       elem.keyPair = keyPair
+                       elem.keypair = keypair
                        elem.dropped = AtomicFalse
                        elem.mutex.Lock()
 
@@ -392,7 +392,7 @@ func (device *Device) RoutineEncryption() {
                        fieldNonce := header[8:16]
 
                        binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
-                       binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
+                       binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
                        binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 
                        // pad content to multiple of 16
@@ -408,7 +408,7 @@ func (device *Device) RoutineEncryption() {
                        // encrypt content and release to consumer
 
                        binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
-                       elem.packet = elem.keyPair.send.Seal(
+                       elem.packet = elem.keypair.send.Seal(
                                header,
                                nonce[:],
                                elem.packet,
index 5c72efd4428d6c0db01f7f8b7afd9b4596a82b0f..9e633eec6d76ca4725baa4f3221a423a1f9893c3 100644 (file)
--- a/timers.go
+++ b/timers.go
@@ -108,7 +108,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
        hs := &peer.handshake
        hs.mutex.Lock()
 
-       kp := &peer.keyPairs
+       kp := &peer.keypairs
        kp.mutex.Lock()
 
        if kp.previous != nil {
@@ -125,7 +125,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
        }
        kp.mutex.Unlock()
 
-       peer.device.indices.Delete(hs.localIndex)
+       peer.device.indexTable.Delete(hs.localIndex)
        hs.Clear()
        hs.mutex.Unlock()
 }