]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Begin implementation of outbound work queue
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 26 Jun 2017 20:07:29 +0000 (22:07 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 26 Jun 2017 20:07:29 +0000 (22:07 +0200)
src/device.go
src/index.go
src/keypair.go
src/noise_protocol.go
src/peer.go
src/send.go

index ce10a634001dce69ee6323fc8dcea5c9ae2f8f8a..4b8cda01fd80731db2965147263d10ea7d7d8f00 100644 (file)
@@ -2,11 +2,14 @@ package main
 
 import (
        "log"
+       "net"
        "sync"
 )
 
 type Device struct {
        mtu               int
+       source            *net.UDPAddr // UDP source address
+       conn              *net.UDPConn // UDP "connection"
        mutex             sync.RWMutex
        peers             map[NoisePublicKey]*Peer
        indices           IndexTable
index 81f71e9160ac4f8fba23347db4494f16352f5781..917851056c358a4c8cf0ebea3d0fb0420964dc67 100644 (file)
@@ -11,10 +11,15 @@ import (
  *
  */
 
+type IndexTableEntry struct {
+       peer      *Peer
+       handshake *Handshake
+       keyPair   *KeyPair
+}
+
 type IndexTable struct {
-       mutex      sync.RWMutex
-       keypairs   map[uint32]*KeyPair
-       handshakes map[uint32]*Peer
+       mutex sync.RWMutex
+       table map[uint32]IndexTableEntry
 }
 
 func randUint32() (uint32, error) {
@@ -32,52 +37,66 @@ func randUint32() (uint32, error) {
 
 func (table *IndexTable) Init() {
        table.mutex.Lock()
-       defer table.mutex.Unlock()
-       table.keypairs = make(map[uint32]*KeyPair)
-       table.handshakes = make(map[uint32]*Peer)
+       table.table = make(map[uint32]IndexTableEntry)
+       table.mutex.Unlock()
 }
 
-func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
+func (table *IndexTable) ClearIndex(index uint32) {
+       if index == 0 {
+               return
+       }
+       table.mutex.Lock()
+       delete(table.table, index)
+       table.mutex.Unlock()
+}
+
+func (table *IndexTable) Insert(key uint32, value IndexTableEntry) {
        table.mutex.Lock()
-       defer table.mutex.Unlock()
+       table.table[key] = value
+       table.mutex.Unlock()
+}
+
+func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
        for {
                // generate random index
 
-               id, err := randUint32()
+               index, err := randUint32()
                if err != nil {
-                       return id, err
+                       return index, err
                }
-               if id == 0 {
+               if index == 0 {
                        continue
                }
 
                // check if index used
 
-               _, ok := table.keypairs[id]
-               if ok {
-                       continue
-               }
-               _, ok = table.handshakes[id]
+               table.mutex.RLock()
+               _, ok := table.table[index]
                if ok {
                        continue
                }
+               table.mutex.RUnlock()
 
-               // clean old index
+               // replace index
 
-               delete(table.handshakes, peer.handshake.localIndex)
-               table.handshakes[id] = peer
-               return id, nil
+               table.mutex.Lock()
+               _, found := table.table[index]
+               if found {
+                       table.mutex.Unlock()
+                       continue
+               }
+               table.table[index] = IndexTableEntry{
+                       peer:      peer,
+                       handshake: &peer.handshake,
+                       keyPair:   nil,
+               }
+               table.mutex.Unlock()
+               return index, nil
        }
 }
 
-func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
-       table.mutex.RLock()
-       defer table.mutex.RUnlock()
-       return table.keypairs[id]
-}
-
-func (table *IndexTable) LookupHandshake(id uint32) *Peer {
+func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
-       return table.handshakes[id]
+       return table.table[id]
 }
index e7961a8eb843af375261eb82c76f06fc624a8dd1..53e123ff3b73d97c0f39217f9738242350048aad 100644 (file)
@@ -16,6 +16,18 @@ type KeyPairs struct {
        mutex      sync.RWMutex
        current    *KeyPair
        previous   *KeyPair
-       next       *KeyPair
-       newKeyPair chan bool
+       next       *KeyPair  // not yet "confirmed by transport"
+       newKeyPair chan bool // signals when "current" has been updated
+}
+
+func (kp *KeyPairs) Init() {
+       kp.mutex.Lock()
+       kp.newKeyPair = make(chan bool, 5)
+       kp.mutex.Unlock()
+}
+
+func (kp *KeyPairs) Current() *KeyPair {
+       kp.mutex.RLock()
+       defer kp.mutex.RUnlock()
+       return kp.current
 }
index a16908a5e9d56b6a52285111b9ae413bdace7ab8..bf1db9b39639897a9916095bec0220e5861676c0 100644 (file)
@@ -120,13 +120,15 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
                return nil, err
        }
 
+       device.indices.ClearIndex(handshake.localIndex)
+       handshake.localIndex, err = device.indices.NewIndex(peer)
+
        // assign index
 
        var msg MessageInitiation
 
        msg.Type = MessageInitiationType
        msg.Ephemeral = handshake.localEphemeral.publicKey()
-       handshake.localIndex, err = device.indices.NewIndex(peer)
 
        if err != nil {
                return nil, err
@@ -249,6 +251,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        // assign index
 
        var err error
+       device.indices.ClearIndex(handshake.localIndex)
        handshake.localIndex, err = device.indices.NewIndex(peer)
        if err != nil {
                return nil, err
@@ -299,11 +302,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 
        // lookup handshake by reciever
 
-       peer := device.indices.LookupHandshake(msg.Reciever)
-       if peer == nil {
+       lookup := device.indices.Lookup(msg.Reciever)
+       handshake := lookup.handshake
+       if handshake == nil {
                return nil
        }
-       handshake := &peer.handshake
+
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
        if handshake.state != HandshakeInitiationCreated {
@@ -345,7 +349,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        handshake.remoteIndex = msg.Sender
        handshake.state = HandshakeResponseConsumed
 
-       return peer
+       return lookup.peer
 }
 
 func (peer *Peer) NewKeyPair() *KeyPair {
@@ -355,13 +359,16 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 
        // derive keys
 
+       var isInitiator bool
        var sendKey [chacha20poly1305.KeySize]byte
        var recvKey [chacha20poly1305.KeySize]byte
 
        if handshake.state == HandshakeResponseConsumed {
                sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+               isInitiator = true
        } else if handshake.state == HandshakeResponseCreated {
                recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+               isInitiator = false
        } else {
                return nil
        }
@@ -369,16 +376,40 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        // create AEAD instances
 
        var keyPair KeyPair
+
        keyPair.send, _ = chacha20poly1305.New(sendKey[:])
        keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
        keyPair.sendNonce = 0
        keyPair.recvNonce = 0
 
+       // remap index
+
+       peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
+               peer:      peer,
+               keyPair:   &keyPair,
+               handshake: nil,
+       })
+       handshake.localIndex = 0
+
+       // rotate key pairs
+
+       func() {
+               kp := &peer.keyPairs
+               kp.mutex.Lock()
+               defer kp.mutex.Unlock()
+               if isInitiator {
+                       kp.previous = peer.keyPairs.current
+                       kp.current = &keyPair
+                       kp.newKeyPair <- true
+               } else {
+                       kp.next = &keyPair
+               }
+       }()
+
        // zero handshake
 
        handshake.chainKey = [blake2s.Size]byte{}
        handshake.localEphemeral = NoisePrivateKey{}
        peer.handshake.state = HandshakeZeroed
-
        return &keyPair
 }
index 42b9e8d5de62c808057c7bdad5486e87434f328f..6a879cb3454bf622d40bd404fac2ed0982df9fcf 100644 (file)
@@ -14,8 +14,7 @@ const (
 
 type Peer struct {
        mutex                       sync.RWMutex
-       endpointIP                  net.IP        //
-       endpointPort                uint16        //
+       endpoint                    *net.UDPAddr
        persistentKeepaliveInterval time.Duration // 0 = disabled
        keyPairs                    KeyPairs
        handshake                   Handshake
@@ -35,6 +34,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 
        peer.mutex.Lock()
        peer.device = device
+       peer.keyPairs.Init()
        peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
 
        // map public key
index 97903206d99a78cb9e9530f75bb885e9b6fdbc61..da5905d5491aa6960cf7070b210e5939ff0d03e8 100644 (file)
@@ -1,9 +1,11 @@
 package main
 
 import (
+       "encoding/binary"
+       "golang.org/x/crypto/chacha20poly1305"
        "net"
        "sync"
-       "sync/atomic"
+       "time"
 )
 
 /* Handles outbound flow
@@ -70,85 +72,115 @@ func (device *Device) SendPacket(packet []byte) {
  *
  * TODO: avoid dynamic allocation of work queue elements
  */
-func (peer *Peer) ConsumeOutboundPackets() {
+func (peer *Peer) RoutineOutboundNonceWorker() {
+       var packet []byte
+       var keyPair *KeyPair
+       var flushTimer time.Timer
+
        for {
-               // wait for key pair
-               keyPair := func() *KeyPair {
-                       peer.keyPairs.mutex.RLock()
-                       defer peer.keyPairs.mutex.RUnlock()
-                       return peer.keyPairs.current
-               }()
-               if keyPair == nil {
-                       if len(peer.queueOutboundRouting) > 0 {
-                               // TODO: start handshake
-                               <-peer.keyPairs.newKeyPair
-                       }
-                       continue
+
+               // wait for packet
+
+               if packet == nil {
+                       packet = <-peer.queueOutboundRouting
                }
 
-               // assign packets key pair
-               for {
+               // wait for key pair
+
+               for keyPair == nil {
+                       flushTimer.Reset(time.Second * 10)
+                       // TODO: Handshake or NOP
                        select {
                        case <-peer.keyPairs.newKeyPair:
-                       default:
-                       case <-peer.keyPairs.newKeyPair:
-                       case packet := <-peer.queueOutboundRouting:
+                               keyPair = peer.keyPairs.Current()
+                               continue
+                       case <-flushTimer.C:
+                               size := len(peer.queueOutboundRouting)
+                               for i := 0; i < size; i += 1 {
+                                       <-peer.queueOutboundRouting
+                               }
+                               packet = nil
+                       }
+                       break
+               }
+
+               // process current packet
+
+               if packet != nil {
 
-                               // create new work element
+                       // create work element
 
-                               work := new(OutboundWorkQueueElement)
-                               work.wg.Add(1)
-                               work.keyPair = keyPair
-                               work.packet = packet
-                               work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+                       work := new(OutboundWorkQueueElement)
+                       work.wg.Add(1)
+                       work.keyPair = keyPair
+                       work.packet = packet
+                       work.nonce = keyPair.sendNonce
 
-                               peer.queueOutbound <- work
+                       packet = nil
+                       peer.queueOutbound <- work
+                       keyPair.sendNonce += 1
 
-                               // drop packets until there is room
+                       // drop packets until there is space
 
+                       func() {
                                for {
                                        select {
                                        case peer.device.queueWorkOutbound <- work:
-                                               break
+                                               return
                                        default:
                                                drop := <-peer.device.queueWorkOutbound
                                                drop.packet = nil
                                                drop.wg.Done()
                                        }
                                }
-                       }
+                       }()
                }
        }
 }
 
+/* Go routine
+ *
+ * sequentially reads packets from queue and sends to endpoint
+ *
+ */
 func (peer *Peer) RoutineSequential() {
        for work := range peer.queueOutbound {
                work.wg.Wait()
+
+               // check if dropped ("ghost packet")
+
                if work.packet == nil {
                        continue
                }
+
+               //
+
        }
 }
 
-func (device *Device) EncryptionWorker() {
-       for {
-               work := <-device.queueWorkOutbound
-
-               func() {
-                       defer work.wg.Done()
+func (device *Device) RoutineEncryptionWorker() {
+       var nonce [chacha20poly1305.NonceSize]byte
+       for work := range device.queueWorkOutbound {
+               // pad packet
 
-                       // pad packet
-                       padding := device.mtu - len(work.packet)
-                       if padding < 0 {
-                               work.packet = nil
-                               return
-                       }
-                       for n := 0; n < padding; n += 1 {
-                               work.packet = append(work.packet, 0) // TODO: gotta be a faster way
-                       }
+               padding := device.mtu - len(work.packet)
+               if padding < 0 {
+                       work.packet = nil
+                       work.wg.Done()
+               }
+               for n := 0; n < padding; n += 1 {
+                       work.packet = append(work.packet, 0)
+               }
 
-                       //
+               // encrypt
 
-               }()
+               binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
+               work.packet = work.keyPair.send.Seal(
+                       work.packet[:0],
+                       nonce[:],
+                       work.packet,
+                       nil,
+               )
+               work.wg.Done()
        }
 }