]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Begin work on outbound packet flow
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 26 Jun 2017 11:14:02 +0000 (13:14 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 26 Jun 2017 11:14:02 +0000 (13:14 +0200)
src/cookie.go [new file with mode: 0644]
src/device.go
src/keypair.go
src/main.go
src/noise_protocol.go
src/noise_test.go
src/peer.go
src/routing.go
src/send.go [new file with mode: 0644]

diff --git a/src/cookie.go b/src/cookie.go
new file mode 100644 (file)
index 0000000..a6987a2
--- /dev/null
@@ -0,0 +1,39 @@
+package main
+
+import (
+       "errors"
+       "golang.org/x/crypto/blake2s"
+)
+
+func CalculateCookie(peer *Peer, msg []byte) {
+       size := len(msg)
+
+       if size < blake2s.Size128*2 {
+               panic(errors.New("bug: message too short"))
+       }
+
+       startMac1 := size - (blake2s.Size128 * 2)
+       startMac2 := size - blake2s.Size128
+
+       mac1 := msg[startMac1 : startMac1+blake2s.Size128]
+       mac2 := msg[startMac2 : startMac2+blake2s.Size128]
+
+       peer.mutex.RLock()
+       defer peer.mutex.RUnlock()
+
+       // set mac1
+
+       func() {
+               mac, _ := blake2s.New128(peer.macKey[:])
+               mac.Write(msg[:startMac1])
+               mac.Sum(mac1[:0])
+       }()
+
+       // set mac2
+
+       if peer.cookie != nil {
+               mac, _ := blake2s.New128(peer.cookie)
+               mac.Write(msg[:startMac2])
+               mac.Sum(mac2[:0])
+       }
+}
index 996903421614ef1b53bf5182d09ac5815366f3b8..ce10a634001dce69ee6323fc8dcea5c9ae2f8f8a 100644 (file)
@@ -1,18 +1,22 @@
 package main
 
 import (
+       "log"
        "sync"
 )
 
 type Device struct {
-       mutex        sync.RWMutex
-       peers        map[NoisePublicKey]*Peer
-       indices      IndexTable
-       privateKey   NoisePrivateKey
-       publicKey    NoisePublicKey
-       fwMark       uint32
-       listenPort   uint16
-       routingTable RoutingTable
+       mtu               int
+       mutex             sync.RWMutex
+       peers             map[NoisePublicKey]*Peer
+       indices           IndexTable
+       privateKey        NoisePrivateKey
+       publicKey         NoisePublicKey
+       fwMark            uint32
+       listenPort        uint16
+       routingTable      RoutingTable
+       logger            log.Logger
+       queueWorkOutbound chan *OutboundWorkQueueElement
 }
 
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
index e434c74dacb38676f2c199b035d2cf1d77240118..e7961a8eb843af375261eb82c76f06fc624a8dd1 100644 (file)
@@ -2,11 +2,20 @@ package main
 
 import (
        "crypto/cipher"
+       "sync"
 )
 
 type KeyPair struct {
        recv      cipher.AEAD
-       recvNonce NoiseNonce
+       recvNonce uint64
        send      cipher.AEAD
-       sendNonce NoiseNonce
+       sendNonce uint64
+}
+
+type KeyPairs struct {
+       mutex      sync.RWMutex
+       current    *KeyPair
+       previous   *KeyPair
+       next       *KeyPair
+       newKeyPair chan bool
 }
index af336f03d63ffb6efc4e36fc305cc16a596bc614..b6f6deb92bb698d776aecd7874d35124ae771bc9 100644 (file)
@@ -1,6 +1,8 @@
 package main
 
-import "fmt"
+import (
+       "fmt"
+)
 
 func main() {
        fd, err := CreateTUN("test0")
@@ -8,9 +10,9 @@ func main() {
 
        queue := make(chan []byte, 1000)
 
-       var device Device
+       // var device Device
 
-       go OutgoingRoutingWorker(&device, queue)
+       // go OutgoingRoutingWorker(&device, queue)
 
        for {
                tmp := make([]byte, 1<<16)
index 7f26cf1100577a8333335b8b871f8ca34f03b386..a16908a5e9d56b6a52285111b9ae413bdace7ab8 100644 (file)
@@ -9,9 +9,9 @@ import (
 )
 
 const (
-       HandshakeReset = iota
-       HandshakeInitialCreated
-       HandshakeInitialConsumed
+       HandshakeZeroed = iota
+       HandshakeInitiationCreated
+       HandshakeInitiationConsumed
        HandshakeResponseCreated
        HandshakeResponseConsumed
 )
@@ -24,13 +24,19 @@ const (
 )
 
 const (
-       MessageInitalType         = 1
+       MessageInitiationType     = 1
        MessageResponseType       = 2
        MessageCookieResponseType = 3
        MessageTransportType      = 4
 )
 
-type MessageInital struct {
+/* Type is an 8-bit field, followed by 3 nul bytes,
+ * by marshalling the messages in little-endian byteorder
+ * we can treat these as a 32-bit int
+ *
+ */
+
+type MessageInitiation struct {
        Type      uint32
        Sender    uint32
        Ephemeral NoisePublicKey
@@ -73,9 +79,9 @@ type Handshake struct {
 }
 
 var (
-       ZeroNonce      [chacha20poly1305.NonceSize]byte
        InitalChainKey [blake2s.Size]byte
        InitalHash     [blake2s.Size]byte
+       ZeroNonce      [chacha20poly1305.NonceSize]byte
 )
 
 func init() {
@@ -83,23 +89,23 @@ func init() {
        InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
 }
 
-func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
        return KDF1(c[:], data)
 }
 
-func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
        return blake2s.Sum256(append(h[:], data...))
 }
 
-func (h *Handshake) addToHash(data []byte) {
-       h.hash = addToHash(h.hash, data)
+func (h *Handshake) mixHash(data []byte) {
+       h.hash = mixHash(h.hash, data)
 }
 
-func (h *Handshake) addToChainKey(data []byte) {
-       h.chainKey = addToChainKey(h.chainKey, data)
+func (h *Handshake) mixKey(data []byte) {
+       h.chainKey = mixKey(h.chainKey, data)
 }
 
-func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
+func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
        handshake := &peer.handshake
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
@@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 
        var err error
        handshake.chainKey = InitalChainKey
-       handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
+       handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
        handshake.localEphemeral, err = newPrivateKey()
        if err != nil {
                return nil, err
@@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 
        // assign index
 
-       var msg MessageInital
+       var msg MessageInitiation
 
-       msg.Type = MessageInitalType
+       msg.Type = MessageInitiationType
        msg.Ephemeral = handshake.localEphemeral.publicKey()
        handshake.localIndex, err = device.indices.NewIndex(peer)
 
@@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
        }
 
        msg.Sender = handshake.localIndex
-       handshake.addToChainKey(msg.Ephemeral[:])
-       handshake.addToHash(msg.Ephemeral[:])
+       handshake.mixKey(msg.Ephemeral[:])
+       handshake.mixHash(msg.Ephemeral[:])
 
-       // encrypt identity key
+       // encrypt static key
 
        func() {
                var key [chacha20poly1305.KeySize]byte
@@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
                aead, _ := chacha20poly1305.New(key[:])
                aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
        }()
-       handshake.addToHash(msg.Static[:])
+       handshake.mixHash(msg.Static[:])
 
        // encrypt timestamp
 
@@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
                aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
        }()
 
-       handshake.addToHash(msg.Timestamp[:])
-       handshake.state = HandshakeInitialCreated
+       handshake.mixHash(msg.Timestamp[:])
+       handshake.state = HandshakeInitiationCreated
 
        return &msg, nil
 }
 
-func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
-       if msg.Type != MessageInitalType {
-               panic(errors.New("bug: invalid inital message type"))
+func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
+       if msg.Type != MessageInitiationType {
+               return nil
        }
 
-       hash := addToHash(InitalHash, device.publicKey[:])
-       hash = addToHash(hash, msg.Ephemeral[:])
-       chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
+       hash := mixHash(InitalHash, device.publicKey[:])
+       hash = mixHash(hash, msg.Ephemeral[:])
+       chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
 
-       // decrypt identity key
+       // decrypt static key
 
        var err error
        var peerPK NoisePublicKey
@@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
        if err != nil {
                return nil
        }
-       hash = addToHash(hash, msg.Static[:])
+       hash = mixHash(hash, msg.Static[:])
 
        // find peer
 
@@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
        if err != nil {
                return nil
        }
-       hash = addToHash(hash, msg.Timestamp[:])
+       hash = mixHash(hash, msg.Timestamp[:])
 
        // check for replay attack
 
@@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
                return nil
        }
 
-       // check for flood attack
+       // TODO: check for flood attack
 
        // update handshake state
 
@@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
        handshake.remoteIndex = msg.Sender
        handshake.remoteEphemeral = msg.Ephemeral
        handshake.lastTimestamp = timestamp
-       handshake.state = HandshakeInitialConsumed
+       handshake.state = HandshakeInitiationConsumed
        return peer
 }
 
@@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
 
-       if handshake.state != HandshakeInitialConsumed {
-               panic(errors.New("bug: handshake initation must be consumed first"))
+       if handshake.state != HandshakeInitiationConsumed {
+               return nil, errors.New("handshake initation must be consumed first")
        }
 
        // assign index
@@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
                return nil, err
        }
        msg.Ephemeral = handshake.localEphemeral.publicKey()
-       handshake.addToHash(msg.Ephemeral[:])
+       handshake.mixHash(msg.Ephemeral[:])
 
        func() {
                ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
-               handshake.addToChainKey(ss[:])
+               handshake.mixKey(ss[:])
                ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
-               handshake.addToChainKey(ss[:])
+               handshake.mixKey(ss[:])
        }()
 
        // add preshared key (psk)
@@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        var tau [blake2s.Size]byte
        var key [chacha20poly1305.KeySize]byte
        handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
-       handshake.addToHash(tau[:])
+       handshake.mixHash(tau[:])
 
        func() {
                aead, _ := chacha20poly1305.New(key[:])
                aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
-               handshake.addToHash(msg.Empty[:])
+               handshake.mixHash(msg.Empty[:])
        }()
 
        handshake.state = HandshakeResponseCreated
@@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        if msg.Type != MessageResponseType {
-               panic(errors.New("bug: invalid message type"))
+               return nil
        }
 
        // lookup handshake by reciever
@@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        handshake := &peer.handshake
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
-       if handshake.state != HandshakeInitialCreated {
+       if handshake.state != HandshakeInitiationCreated {
                return nil
        }
 
        // finish 3-way DH
 
-       hash := addToHash(handshake.hash, msg.Ephemeral[:])
+       hash := mixHash(handshake.hash, msg.Ephemeral[:])
        chainKey := handshake.chainKey
 
        func() {
                ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
-               chainKey = addToChainKey(chainKey, ss[:])
+               chainKey = mixKey(chainKey, ss[:])
                ss = device.privateKey.sharedSecret(msg.Ephemeral)
-               chainKey = addToChainKey(chainKey, ss[:])
+               chainKey = mixKey(chainKey, ss[:])
        }()
 
        // add preshared key (psk)
@@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        var tau [blake2s.Size]byte
        var key [chacha20poly1305.KeySize]byte
        chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
-       hash = addToHash(hash, tau[:])
+       hash = mixHash(hash, tau[:])
 
        // authenticate
 
@@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        if err != nil {
                return nil
        }
-       hash = addToHash(hash, msg.Empty[:])
+       hash = mixHash(hash, msg.Empty[:])
 
        // update handshake state
 
@@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        keyPair.sendNonce = 0
        keyPair.recvNonce = 0
 
-       peer.handshake.state = HandshakeReset
+       // zero handshake
+
+       handshake.chainKey = [blake2s.Size]byte{}
+       handshake.localEphemeral = NoisePrivateKey{}
+       peer.handshake.state = HandshakeZeroed
 
        return &keyPair
 }
index ddabf8e930d148c2929db549dca12126193c2033..8450c1c25046fc004988ebb182247578c37dab8d 100644 (file)
@@ -67,13 +67,13 @@ func TestNoiseHandshake(t *testing.T) {
 
        t.Log("exchange initiation message")
 
-       msg1, err := dev1.CreateMessageInitial(peer2)
+       msg1, err := dev1.CreateMessageInitiation(peer2)
        assertNil(t, err)
 
        packet := make([]byte, 0, 256)
        writer := bytes.NewBuffer(packet)
        err = binary.Write(writer, binary.LittleEndian, msg1)
-       peer := dev2.ConsumeMessageInitial(msg1)
+       peer := dev2.ConsumeMessageInitiation(msg1)
        if peer == nil {
                t.Fatal("handshake failed at initiation message")
        }
index f6eb555ff6ade422198267698257a47b70039b71..42b9e8d5de62c808057c7bdad5486e87434f328f 100644 (file)
@@ -1,39 +1,64 @@
 package main
 
 import (
+       "errors"
+       "golang.org/x/crypto/blake2s"
        "net"
        "sync"
        "time"
 )
 
+const (
+       OutboundQueueSize = 64
+)
+
 type Peer struct {
        mutex                       sync.RWMutex
        endpointIP                  net.IP        //
        endpointPort                uint16        //
        persistentKeepaliveInterval time.Duration // 0 = disabled
+       keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
+       macKey                      [blake2s.Size]byte // Hash(Label-Mac1 || publicKey)
+       cookie                      []byte             // cookie
+       cookieExpire                time.Time
+       queueInbound                chan []byte
+       queueOutbound               chan *OutboundWorkQueueElement
+       queueOutboundRouting        chan []byte
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        var peer Peer
 
+       // create peer
+
+       peer.mutex.Lock()
+       peer.device = device
+       peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
+
        // map public key
 
        device.mutex.Lock()
+       _, ok := device.peers[pk]
+       if ok {
+               panic(errors.New("bug: adding existing peer"))
+       }
        device.peers[pk] = &peer
        device.mutex.Unlock()
 
-       // precompute
+       // precompute DH
 
-       peer.mutex.Lock()
-       peer.device = device
-       func(h *Handshake) {
-               h.mutex.Lock()
-               h.remoteStatic = pk
-               h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
-               h.mutex.Unlock()
-       }(&peer.handshake)
+       handshake := &peer.handshake
+       handshake.mutex.Lock()
+       handshake.remoteStatic = pk
+       handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
+
+       // compute mac key
+
+       peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...))
+
+       handshake.mutex.Unlock()
        peer.mutex.Unlock()
 
        return &peer
index 553df117c3573a3cda8c83bb0cbdb159fd8493be..4189c2582d0b8bedd41c0c897f527efb9015087d 100644 (file)
@@ -2,7 +2,6 @@ package main
 
 import (
        "errors"
-       "fmt"
        "net"
        "sync"
 )
@@ -52,25 +51,3 @@ func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
        defer table.mutex.RUnlock()
        return table.IPv6.Lookup(address)
 }
-
-func OutgoingRoutingWorker(device *Device, queue chan []byte) {
-       for {
-               packet := <-queue
-               switch packet[0] >> 4 {
-
-               case IPv4version:
-                       dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
-                       peer := device.routingTable.LookupIPv4(dst)
-                       fmt.Println("IPv4", peer)
-
-               case IPv6version:
-                       dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
-                       peer := device.routingTable.LookupIPv6(dst)
-                       fmt.Println("IPv6", peer)
-
-               default:
-                       // todo: log
-                       fmt.Println("Unknown IP version")
-               }
-       }
-}
diff --git a/src/send.go b/src/send.go
new file mode 100644 (file)
index 0000000..9790320
--- /dev/null
@@ -0,0 +1,154 @@
+package main
+
+import (
+       "net"
+       "sync"
+       "sync/atomic"
+)
+
+/* Handles outbound flow
+ *
+ * 1. TUN queue
+ * 2. Routing
+ * 3. Per peer queuing
+ * 4. (work queuing)
+ *
+ */
+
+type OutboundWorkQueueElement struct {
+       wg      sync.WaitGroup
+       packet  []byte
+       nonce   uint64
+       keyPair *KeyPair
+}
+
+func (device *Device) SendPacket(packet []byte) {
+
+       // lookup peer
+
+       var peer *Peer
+       switch packet[0] >> 4 {
+       case IPv4version:
+               dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+               peer = device.routingTable.LookupIPv4(dst)
+
+       case IPv6version:
+               dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+               peer = device.routingTable.LookupIPv6(dst)
+
+       default:
+               device.logger.Println("unknown IP version")
+               return
+       }
+
+       if peer == nil {
+               return
+       }
+
+       // insert into peer queue
+
+       for {
+               select {
+               case peer.queueOutboundRouting <- packet:
+               default:
+                       select {
+                       case <-peer.queueOutboundRouting:
+                       default:
+                       }
+                       continue
+               }
+               break
+       }
+}
+
+/* Go routine
+ *
+ *
+ * 1. waits for handshake.
+ * 2. assigns key pair & nonce
+ * 3. inserts to working queue
+ *
+ * TODO: avoid dynamic allocation of work queue elements
+ */
+func (peer *Peer) ConsumeOutboundPackets() {
+       for {
+               // wait for key pair
+               keyPair := func() *KeyPair {
+                       peer.keyPairs.mutex.RLock()
+                       defer peer.keyPairs.mutex.RUnlock()
+                       return peer.keyPairs.current
+               }()
+               if keyPair == nil {
+                       if len(peer.queueOutboundRouting) > 0 {
+                               // TODO: start handshake
+                               <-peer.keyPairs.newKeyPair
+                       }
+                       continue
+               }
+
+               // assign packets key pair
+               for {
+                       select {
+                       case <-peer.keyPairs.newKeyPair:
+                       default:
+                       case <-peer.keyPairs.newKeyPair:
+                       case packet := <-peer.queueOutboundRouting:
+
+                               // create new work element
+
+                               work := new(OutboundWorkQueueElement)
+                               work.wg.Add(1)
+                               work.keyPair = keyPair
+                               work.packet = packet
+                               work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+
+                               peer.queueOutbound <- work
+
+                               // drop packets until there is room
+
+                               for {
+                                       select {
+                                       case peer.device.queueWorkOutbound <- work:
+                                               break
+                                       default:
+                                               drop := <-peer.device.queueWorkOutbound
+                                               drop.packet = nil
+                                               drop.wg.Done()
+                                       }
+                               }
+                       }
+               }
+       }
+}
+
+func (peer *Peer) RoutineSequential() {
+       for work := range peer.queueOutbound {
+               work.wg.Wait()
+               if work.packet == nil {
+                       continue
+               }
+       }
+}
+
+func (device *Device) EncryptionWorker() {
+       for {
+               work := <-device.queueWorkOutbound
+
+               func() {
+                       defer work.wg.Done()
+
+                       // pad packet
+                       padding := device.mtu - len(work.packet)
+                       if padding < 0 {
+                               work.packet = nil
+                               return
+                       }
+                       for n := 0; n < padding; n += 1 {
+                               work.packet = append(work.packet, 0) // TODO: gotta be a faster way
+                       }
+
+                       //
+
+               }()
+       }
+}