]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Completed initial version of outbound flow
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 30 Jun 2017 12:41:08 +0000 (14:41 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 30 Jun 2017 12:41:08 +0000 (14:41 +0200)
17 files changed:
src/config.go
src/constants.go
src/device.go
src/handshake.go
src/helper_test.go
src/index.go
src/keypair.go
src/logger.go
src/macs_test.go
src/main.go
src/misc.go
src/noise_helpers.go
src/noise_protocol.go
src/noise_test.go
src/peer.go
src/send.go
src/tun_linux.go

index 2f8dc76de543023fc0821eef5b75c951cbc84cf1..8281581a194035bacf2be525c25a3c5f2349577c 100644 (file)
@@ -8,7 +8,6 @@ import (
        "net"
        "strconv"
        "strings"
-       "time"
 )
 
 // #include <errno.h>
@@ -51,9 +50,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
                send("private_key=" + device.privateKey.ToHex())
        }
 
-       if device.address != nil {
-               send(fmt.Sprintf("listen_port=%d", device.address.Port))
-       }
+       send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
 
        for _, peer := range device.peers {
                func() {
@@ -106,7 +103,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                }
                key := parts[0]
                value := parts[1]
-               logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
 
                switch key {
 
@@ -118,13 +114,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                device.privateKey = NoisePrivateKey{}
                                device.mutex.Unlock()
                        } else {
-                               device.mutex.Lock()
-                               err := device.privateKey.FromHex(value)
-                               device.mutex.Unlock()
+                               var sk NoisePrivateKey
+                               err := sk.FromHex(value)
                                if err != nil {
                                        logger.Println("Failed to set private_key:", err)
                                        return &IPCError{Code: ipcErrorInvalidValue}
                                }
+                               device.SetPrivateKey(sk)
                        }
 
                case "listen_port":
@@ -134,12 +130,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                logger.Println("Failed to set listen_port:", err)
                                return &IPCError{Code: ipcErrorInvalidValue}
                        }
-                       device.mutex.Lock()
-                       if device.address == nil {
-                               device.address = &net.UDPAddr{}
-                       }
-                       device.address.Port = port
-                       device.mutex.Unlock()
+                       device.net.mutex.Lock()
+                       device.net.addr.Port = port
+                       device.net.conn, err = net.ListenUDP("udp", device.net.addr)
+                       device.net.mutex.Unlock()
 
                case "fwmark":
                        logger.Println("FWMark not handled yet")
@@ -200,13 +194,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }
 
                        case "endpoint":
-                               ip := net.ParseIP(value)
-                               if ip == nil {
+                               addr, err := net.ResolveUDPAddr("udp", value)
+                               if err != nil {
                                        logger.Println("Failed to set endpoint:", value)
                                        return &IPCError{Code: ipcErrorInvalidValue}
                                }
                                peer.mutex.Lock()
-                               // peer.endpoint = ip FIX
+                               peer.endpoint = addr
                                peer.mutex.Unlock()
 
                        case "persistent_keepalive_interval":
@@ -216,7 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        return &IPCError{Code: ipcErrorInvalidValue}
                                }
                                peer.mutex.Lock()
-                               peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
+                               peer.persistentKeepaliveInterval = uint64(secs)
                                peer.mutex.Unlock()
 
                        case "replace_allowed_ips":
index e8cdd63e4645d7f83e0a74ff9f7a1ded5fd74930..34217d2dd290c35345ac47b4fcbdca2ae3f74ec2 100644 (file)
@@ -5,15 +5,15 @@ import (
 )
 
 const (
-       RekeyAfterMessage      = (1 << 64) - (1 << 16) - 1
-       RekeyAfterTime         = time.Second * 120
-       RekeyAttemptTime       = time.Second * 90
-       RekeyTimeout           = time.Second * 5 // TODO: Exponential backoff
-       RejectAfterTime        = time.Second * 180
-       RejectAfterMessage     = (1 << 64) - (1 << 4) - 1
-       KeepaliveTimeout       = time.Second * 10
-       CookieRefreshTime      = time.Second * 2
-       MaxHandshakeAttempTime = time.Second * 90
+       RekeyAfterMessages      = (1 << 64) - (1 << 16) - 1
+       RekeyAfterTime          = time.Second * 120
+       RekeyAttemptTime        = time.Second * 90
+       RekeyTimeout            = time.Second * 5 // TODO: Exponential backoff
+       RejectAfterTime         = time.Second * 180
+       RejectAfterMessages     = (1 << 64) - (1 << 4) - 1
+       KeepaliveTimeout        = time.Second * 10
+       CookieRefreshTime       = time.Second * 2
+       MaxHandshakeAttemptTime = time.Second * 90
 )
 
 const (
index 52ac6a499c5354095f8cf54d040fa1801e5d08fb..a33e923a130b51d2c1ba571c1e9253a772dfc5a3 100644 (file)
@@ -7,16 +7,21 @@ import (
 )
 
 type Device struct {
-       mtu          int
-       fwMark       uint32
-       address      *net.UDPAddr // UDP source address
-       conn         *net.UDPConn // UDP "connection"
+       mtu       int
+       log       *Logger // collection of loggers for levels
+       idCounter uint    // for assigning debug ids to peers
+       fwMark    uint32
+       net       struct {
+               // seperate for performance reasons
+               mutex sync.RWMutex
+               addr  *net.UDPAddr // UDP source address
+               conn  *net.UDPConn // UDP "connection"
+       }
        mutex        sync.RWMutex
        privateKey   NoisePrivateKey
        publicKey    NoisePublicKey
        routingTable RoutingTable
        indices      IndexTable
-       log          *Logger
        queue        struct {
                encryption chan *QueueOutboundElement // parallel work queue
        }
@@ -44,17 +49,29 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
        }
 }
 
-func NewDevice(tun TUNDevice) *Device {
+func NewDevice(tun TUNDevice, logLevel int) *Device {
        device := new(Device)
 
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
-       device.log = NewLogger()
+       device.log = NewLogger(logLevel)
        device.peers = make(map[NoisePublicKey]*Peer)
        device.indices.Init()
        device.routingTable.Reset()
 
+       // listen
+
+       device.net.mutex.Lock()
+       device.net.conn, _ = net.ListenUDP("udp", device.net.addr)
+       addr := device.net.conn.LocalAddr()
+       device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String())
+       device.net.mutex.Unlock()
+
+       // create queues
+
+       device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
+
        // start workers
 
        for i := 0; i < runtime.NumCPU(); i += 1 {
@@ -92,5 +109,11 @@ func (device *Device) RemoveAllPeers() {
                peer.mutex.Lock()
                delete(device.peers, key)
                peer.Close()
+               peer.mutex.Unlock()
        }
 }
+
+func (device *Device) Close() {
+       device.RemoveAllPeers()
+       close(device.queue.encryption)
+}
index 238c3396e09e7e476ca505210f3e50ee821986a4..8f8e2f9a80399efc036c8e45ced6b16f83d8dcd1 100644 (file)
@@ -24,91 +24,163 @@ func (peer *Peer) SendKeepAlive() bool {
        return true
 }
 
-func (peer *Peer) RoutineHandshakeInitiator() {
-       var ongoing bool
-       var begun time.Time
-       var attempts uint
-       var timeout time.Timer
-
-       device := peer.device
-       work := new(QueueOutboundElement)
-       buffer := make([]byte, 0, 1024)
-
-       queueHandshakeInitiation := func() error {
-               work.mutex.Lock()
-               defer work.mutex.Unlock()
+func StoppedTimer() *time.Timer {
+       timer := time.NewTimer(time.Hour)
+       if !timer.Stop() {
+               <-timer.C
+       }
+       return timer
+}
 
-               // create initiation
+/* Called when a new authenticated message has been send
+ *
+ * TODO: This might be done in a faster way
+ */
+func (peer *Peer) KeepKeyFreshSending() {
+       send := func() bool {
+               peer.keyPairs.mutex.RLock()
+               defer peer.keyPairs.mutex.RUnlock()
 
-               msg, err := device.CreateMessageInitiation(peer)
-               if err != nil {
-                       return err
+               kp := peer.keyPairs.current
+               if kp == nil {
+                       return false
                }
 
-               // create "work" element
+               if !kp.isInitiator {
+                       return false
+               }
 
-               writer := bytes.NewBuffer(buffer[:0])
-               binary.Write(writer, binary.LittleEndian, &msg)
-               work.packet = writer.Bytes()
-               peer.mac.AddMacs(work.packet)
-               peer.InsertOutbound(work)
-               return nil
+               nonce := atomic.LoadUint64(&kp.sendNonce)
+               if nonce > RekeyAfterMessages {
+                       return true
+               }
+               return time.Now().Sub(kp.created) > RekeyAfterTime
+       }()
+       if send {
+               sendSignal(peer.signal.handshakeBegin)
        }
+}
 
-       for {
-               select {
-               case <-peer.signal.stopInitiator:
-                       return
-
-               case <-peer.signal.newHandshake:
-                       if ongoing {
-                               continue
-                       }
-
-                       // create handshake
-
-                       err := queueHandshakeInitiation()
-                       if err != nil {
-                               device.log.Error.Println("Failed to create initiation message:", err)
-                       }
-
-                       // log when we began
-
-                       begun = time.Now()
-                       ongoing = true
-                       attempts = 0
-                       timeout.Reset(RekeyTimeout)
-
-               case <-peer.timer.sendKeepalive.C:
-
-                       // active keep-alives
-
-                       peer.SendKeepAlive()
+/* This is the state machine for handshake initiation
+ *
+ * Associated with this routine is the signal "handshakeBegin"
+ * The routine will read from the "handshakeBegin" channel
+ * at most every RekeyTimeout or with exponential backoff
+ *
+ * Implements exponential backoff for retries
+ */
+func (peer *Peer) RoutineHandshakeInitiator() {
+       work := new(QueueOutboundElement)
+       device := peer.device
+       buffer := make([]byte, 1024)
+       logger := device.log.Debug
+       timeout := time.NewTimer(time.Hour)
 
-               case <-peer.timer.handshakeTimeout.C:
+       logger.Println("Routine, handshake initator, started for peer", peer.id)
 
-                       // check if we can stop trying
+       func() {
+               for {
+                       var attempts uint
+                       var deadline time.Time
 
-                       if time.Now().Sub(begun) > MaxHandshakeAttempTime {
-                               peer.signal.flushNonceQueue <- true
-                               peer.timer.sendKeepalive.Stop()
-                               ongoing = false
-                               continue
+                       select {
+                       case <-peer.signal.handshakeBegin:
+                       case <-peer.signal.stop:
+                               return
                        }
 
-                       // otherwise, try again (exponental backoff)
-
-                       attempts += 1
-                       err := queueHandshakeInitiation()
-                       if err != nil {
-                               device.log.Error.Println("Failed to create initiation message:", err)
+               HandshakeLoop:
+                       for run := true; run; {
+                               // clear completed signal
+
+                               select {
+                               case <-peer.signal.handshakeCompleted:
+                               case <-peer.signal.stop:
+                                       return
+                               default:
+                               }
+
+                               // queue handshake
+
+                               err := func() error {
+                                       work.mutex.Lock()
+                                       defer work.mutex.Unlock()
+
+                                       // create initiation
+
+                                       msg, err := device.CreateMessageInitiation(peer)
+                                       if err != nil {
+                                               return err
+                                       }
+
+                                       // marshal
+
+                                       writer := bytes.NewBuffer(buffer[:0])
+                                       binary.Write(writer, binary.LittleEndian, msg)
+                                       work.packet = writer.Bytes()
+                                       peer.mac.AddMacs(work.packet)
+                                       peer.InsertOutbound(work)
+                                       return nil
+                               }()
+                               if err != nil {
+                                       device.log.Error.Println("Failed to create initiation message:", err)
+                                       break
+                               }
+                               if attempts == 0 {
+                                       deadline = time.Now().Add(MaxHandshakeAttemptTime)
+                               }
+
+                               // set timeout
+
+                               if !timeout.Stop() {
+                                       select {
+                                       case <-timeout.C:
+                                       default:
+                                       }
+                               }
+                               timeout.Reset((1 << attempts) * RekeyTimeout)
+                               attempts += 1
+                               device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
+                               time.Sleep(RekeyTimeout)
+
+                               // wait for handshake or timeout
+
+                               select {
+                               case <-peer.signal.stop:
+                                       return
+
+                               case <-peer.signal.handshakeCompleted:
+                                       break HandshakeLoop
+
+                               default:
+                                       select {
+
+                                       case <-peer.signal.stop:
+                                               return
+
+                                       case <-peer.signal.handshakeCompleted:
+                                               break HandshakeLoop
+
+                                       case <-timeout.C:
+                                               nextTimeout := (1 << attempts) * RekeyTimeout
+                                               if deadline.Before(time.Now().Add(nextTimeout)) {
+                                                       // we do not have time for another attempt
+                                                       peer.signal.flushNonceQueue <- struct{}{}
+                                                       if !peer.timer.sendKeepalive.Stop() {
+                                                               <-peer.timer.sendKeepalive.C
+                                                       }
+                                                       break HandshakeLoop
+                                               }
+                                       }
+                               }
                        }
-                       peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
                }
-       }
+       }()
+
+       logger.Println("Routine, handshake initator, stopped for peer", peer.id)
 }
 
-/* Handles packets related to handshake
+/* Handles incomming packets related to handshake
  *
  *
  */
@@ -140,33 +212,12 @@ func (device *Device) HandshakeWorker(queue chan struct {
                        // check for cookie
 
                case MessageCookieReplyType:
+                       if len(elem.msg) != MessageCookieReplySize {
+                               continue
+                       }
 
-               case MessageTransportType:
-               }
-
-       }
-}
-
-func (device *Device) KeepKeyFresh(peer *Peer) {
-
-       send := func() bool {
-               peer.keyPairs.mutex.RLock()
-               defer peer.keyPairs.mutex.RUnlock()
-
-               kp := peer.keyPairs.current
-               if kp == nil {
-                       return false
-               }
-
-               nonce := atomic.LoadUint64(&kp.sendNonce)
-               if nonce > RekeyAfterMessage {
-                       return true
+               default:
+                       device.log.Error.Println("Invalid message type in handshake queue")
                }
-
-               return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
-       }()
-
-       if send {
-
        }
 }
index 3a5c331622d0b39d330948ac5d9b212d57881144..464292fcf31c0e47ed729ee697ee8a5f451b868f 100644 (file)
@@ -35,7 +35,7 @@ func (tun *DummyTUN) Read(d []byte) (int, error) {
 
 func CreateDummyTUN(name string) (TUNDevice, error) {
        var dummy DummyTUN
-       dummy.mtu = 1024
+       dummy.mtu = 0
        dummy.packets = make(chan []byte, 100)
        return &dummy, nil
 }
@@ -58,7 +58,7 @@ func randDevice(t *testing.T) *Device {
                t.Fatal(err)
        }
        tun, _ := CreateDummyTUN("dummy")
-       device := NewDevice(tun)
+       device := NewDevice(tun, LogLevelError)
        device.SetPrivateKey(sk)
        return device
 }
index 917851056c358a4c8cf0ebea3d0fb0420964dc67..59e20793d8f74d5db5264a80a7f31a5aff48f4ba 100644 (file)
@@ -41,7 +41,7 @@ func (table *IndexTable) Init() {
        table.mutex.Unlock()
 }
 
-func (table *IndexTable) ClearIndex(index uint32) {
+func (table *IndexTable) Delete(index uint32) {
        if index == 0 {
                return
        }
index 0b029cebf6fe57567aa76a07c126616b2c203383..0e845f772e8feefed5f546b2e357307750725c75 100644 (file)
@@ -13,20 +13,27 @@ type KeyPair struct {
        sendNonce   uint64
        isInitiator bool
        created     time.Time
+       id          uint32
 }
 
 type KeyPairs struct {
-       mutex      sync.RWMutex
-       current    *KeyPair
-       previous   *KeyPair
-       next       *KeyPair  // not yet "confirmed by transport"
-       newKeyPair chan bool // signals when "current" has been updated
+       mutex    sync.RWMutex
+       current  *KeyPair
+       previous *KeyPair
+       next     *KeyPair // not yet "confirmed by transport"
 }
 
-func (kp *KeyPairs) Init() {
-       kp.mutex.Lock()
-       kp.newKeyPair = make(chan bool, 5)
-       kp.mutex.Unlock()
+/* Called during recieving to confirm the handshake
+ * was completed correctly
+ */
+func (kp *KeyPairs) Used(key *KeyPair) {
+       if key == kp.next {
+               kp.mutex.Lock()
+               kp.previous = kp.current
+               kp.current = key
+               kp.next = nil
+               kp.mutex.Unlock()
+       }
 }
 
 func (kp *KeyPairs) Current() *KeyPair {
index 117fe5bc9fd0ad3b81305913745a2cae62443fa1..827f9e9b18a70770844f684c8ebac0856b47e2e0 100644 (file)
@@ -1,6 +1,8 @@
 package main
 
 import (
+       "io"
+       "io/ioutil"
        "log"
        "os"
 )
@@ -17,17 +19,30 @@ type Logger struct {
        Error *log.Logger
 }
 
-func NewLogger() *Logger {
+func NewLogger(level int) *Logger {
+       output := os.Stdout
        logger := new(Logger)
-       logger.Debug = log.New(os.Stdout,
+
+       logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
+               if level >= LogLevelDebug {
+                       return output, output, output
+               }
+               if level >= LogLevelInfo {
+                       return output, output, ioutil.Discard
+               }
+               return output, ioutil.Discard, ioutil.Discard
+       }()
+
+       logger.Debug = log.New(logDebug,
                "DEBUG: ",
                log.Ldate|log.Ltime|log.Lshortfile,
        )
-       logger.Info = log.New(os.Stdout,
+
+       logger.Info = log.New(logInfo,
                "INFO: ",
                log.Ldate|log.Ltime|log.Lshortfile,
        )
-       logger.Error = log.New(os.Stdout,
+       logger.Error = log.New(logErr,
                "ERROR: ",
                log.Ldate|log.Ltime|log.Lshortfile,
        )
index fcb64ea8885430f0ad9208163642e13c61ac1699..a2a65035586b04885e1261dd9463bfdfca6a9b95 100644 (file)
@@ -11,6 +11,9 @@ func TestMAC1(t *testing.T) {
        dev1 := randDevice(t)
        dev2 := randDevice(t)
 
+       defer dev1.Close()
+       defer dev2.Close()
+
        peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
        peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
 
@@ -40,6 +43,9 @@ func TestMACs(t *testing.T) {
                device2 := randDevice(t)
                device2.SetPrivateKey(sk2)
 
+               defer device1.Close()
+               defer device2.Close()
+
                peer1 := device2.NewPeer(device1.privateKey.publicKey())
                peer2 := device1.NewPeer(device2.privateKey.publicKey())
 
index 9c76ff4cebd11fac76e54c895538c2e89d74050f..b89af17eaee3b4e8d8f0ed2963242bb59c0a31d1 100644 (file)
@@ -28,7 +28,7 @@ func main() {
                return
        }
 
-       device := NewDevice(tun)
+       device := NewDevice(tun, LogLevelDebug)
 
        // Start configuration lister
 
index e1244d654bf0af396fd432eb48e13d4d8baf8fcd..2bcb14884444950253bf93ef9de5129012aad7e5 100644 (file)
@@ -6,3 +6,10 @@ func min(a uint, b uint) uint {
        }
        return a
 }
+
+func sendSignal(c chan struct{}) {
+       select {
+       case c <- struct{}{}:
+       default:
+       }
+}
index e163acec07099bcf6df9a4c698e441723a62c507..1e622a5bbf2553fbfc3a82c2b3d602eeadd4f201 100644 (file)
@@ -33,6 +33,7 @@ func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
        HMAC(&prk, key, input)
        HMAC(&t0, prk[:], []byte{0x1})
        HMAC(&t1, prk[:], append(t0[:], 0x2))
+       prk = [blake2s.Size]byte{}
        return
 }
 
@@ -42,6 +43,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
        HMAC(&t0, prk[:], []byte{0x1})
        HMAC(&t1, prk[:], append(t0[:], 0x2))
        HMAC(&t2, prk[:], append(t1[:], 0x3))
+       prk = [blake2s.Size]byte{}
        return
 }
 
index 46ceeda5ab78dc8c0de76f398be5786415b4958f..a1a1c7ba980be191b3130ee157759056e1d39e65 100644 (file)
@@ -31,8 +31,9 @@ const (
 )
 
 const (
-       MessageInitiationSize = 148
-       MessageResponseSize   = 92
+       MessageInitiationSize  = 148
+       MessageResponseSize    = 92
+       MessageCookieReplySize = 64
 )
 
 /* Type is an 8-bit field, followed by 3 nul bytes,
@@ -91,16 +92,11 @@ type Handshake struct {
 }
 
 var (
-       InitalChainKey [blake2s.Size]byte
-       InitalHash     [blake2s.Size]byte
-       ZeroNonce      [chacha20poly1305.NonceSize]byte
+       InitialChainKey [blake2s.Size]byte
+       InitialHash     [blake2s.Size]byte
+       ZeroNonce       [chacha20poly1305.NonceSize]byte
 )
 
-func init() {
-       InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
-       InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
-}
-
 func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
        return KDF1(c[:], data)
 }
@@ -117,6 +113,13 @@ func (h *Handshake) mixKey(data []byte) {
        h.chainKey = mixKey(h.chainKey, data)
 }
 
+/* Do basic precomputations
+ */
+func init() {
+       InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
+       InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier))
+}
+
 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
        handshake := &peer.handshake
        handshake.mutex.Lock()
@@ -125,28 +128,30 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        // create ephemeral key
 
        var err error
-       handshake.chainKey = InitalChainKey
-       handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
+       handshake.hash = InitialHash
+       handshake.chainKey = InitialChainKey
        handshake.localEphemeral, err = newPrivateKey()
        if err != nil {
                return nil, err
        }
 
-       device.indices.ClearIndex(handshake.localIndex)
-       handshake.localIndex, err = device.indices.NewIndex(peer)
-
        // assign index
 
-       var msg MessageInitiation
-
-       msg.Type = MessageInitiationType
-       msg.Ephemeral = handshake.localEphemeral.publicKey()
+       device.indices.Delete(handshake.localIndex)
+       handshake.localIndex, err = device.indices.NewIndex(peer)
 
        if err != nil {
                return nil, err
        }
 
-       msg.Sender = handshake.localIndex
+       handshake.mixHash(handshake.remoteStatic[:])
+
+       msg := MessageInitiation{
+               Type:      MessageInitiationType,
+               Ephemeral: handshake.localEphemeral.publicKey(),
+               Sender:    handshake.localIndex,
+       }
+
        handshake.mixKey(msg.Ephemeral[:])
        handshake.mixHash(msg.Ephemeral[:])
 
@@ -185,9 +190,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
                return nil
        }
 
-       hash := mixHash(InitalHash, device.publicKey[:])
+       hash := mixHash(InitialHash, device.publicKey[:])
        hash = mixHash(hash, msg.Ephemeral[:])
-       chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
+       chainKey := mixKey(InitialChainKey, msg.Ephemeral[:])
 
        // decrypt static key
 
@@ -278,7 +283,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        // assign index
 
        var err error
-       device.indices.ClearIndex(handshake.localIndex)
+       device.indices.Delete(handshake.localIndex)
        handshake.localIndex, err = device.indices.NewIndex(peer)
        if err != nil {
                return nil, err
@@ -420,10 +425,15 @@ func (peer *Peer) NewKeyPair() *KeyPair {
                return nil
        }
 
-       // create AEAD instances
+       // zero handshake
+
+       handshake.chainKey = [blake2s.Size]byte{}
+       handshake.localEphemeral = NoisePrivateKey{}
+       peer.handshake.state = HandshakeZeroed
 
-       var keyPair KeyPair
+       // create AEAD instances
 
+       keyPair := new(KeyPair)
        keyPair.send, _ = chacha20poly1305.New(sendKey[:])
        keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
        keyPair.sendNonce = 0
@@ -433,30 +443,32 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 
        peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
                peer:      peer,
-               keyPair:   &keyPair,
+               keyPair:   keyPair,
                handshake: nil,
        })
        handshake.localIndex = 0
 
+       // start timer for keypair
+
        // rotate key pairs
 
+       kp := &peer.keyPairs
        func() {
-               kp := &peer.keyPairs
                kp.mutex.Lock()
                defer kp.mutex.Unlock()
                if isInitiator {
-                       kp.previous = peer.keyPairs.current
-                       kp.current = &keyPair
-                       kp.newKeyPair <- true
+                       if kp.previous != nil {
+                               kp.previous.send = nil
+                               kp.previous.recv = nil
+                               peer.device.indices.Delete(kp.previous.id)
+                       }
+                       kp.previous = kp.current
+                       kp.current = keyPair
+                       sendSignal(peer.signal.newKeyPair)
                } else {
-                       kp.next = &keyPair
+                       kp.next = keyPair
                }
        }()
 
-       // zero handshake
-
-       handshake.chainKey = [blake2s.Size]byte{}
-       handshake.localEphemeral = NoisePrivateKey{}
-       peer.handshake.state = HandshakeZeroed
-       return &keyPair
+       return keyPair
 }
index 02f6bf38f65addc7d57d811899876c7de09c6d79..9b50ff3bb2a9fbdc4733f658515cc9f7580c1da1 100644 (file)
@@ -25,10 +25,12 @@ func TestCurveWrappers(t *testing.T) {
 }
 
 func TestNoiseHandshake(t *testing.T) {
-
        dev1 := randDevice(t)
        dev2 := randDevice(t)
 
+       defer dev1.Close()
+       defer dev2.Close()
+
        peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
        peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
 
index 21cad9d81ea271ad6ab31657417885069da246c5..e885cee2c6b439cd2414619330769655fe6e8db9 100644 (file)
@@ -10,26 +10,29 @@ import (
 const ()
 
 type Peer struct {
+       id                          uint
        mutex                       sync.RWMutex
        endpoint                    *net.UDPAddr
-       persistentKeepaliveInterval time.Duration // 0 = disabled
+       persistentKeepaliveInterval uint64
        keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
        tx_bytes                    uint64
        rx_bytes                    uint64
        time                        struct {
-               lastSend time.Time // last send message
+               lastSend      time.Time // last send message
+               lastHandshake time.Time // last completed handshake
        }
        signal struct {
-               newHandshake    chan bool
-               flushNonceQueue chan bool // empty queued packets
-               stopSending     chan bool // stop sending pipeline
-               stopInitiator   chan bool // stop initiator timer
+               newKeyPair         chan struct{} // (size 1) : a new key pair was generated
+               handshakeBegin     chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
+               handshakeCompleted chan struct{} // (size 1) : handshake completed
+               flushNonceQueue    chan struct{} // (size 1) : empty queued packets
+               stop               chan struct{} // (size 0) : close to stop all goroutines for peer
        }
        timer struct {
-               sendKeepalive    time.Timer
-               handshakeTimeout time.Timer
+               sendKeepalive    *time.Timer
+               handshakeTimeout *time.Timer
        }
        queue struct {
                nonce    chan []byte                // nonce / pre-handshake queue
@@ -39,25 +42,30 @@ type Peer struct {
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
-       var peer Peer
-
        // create peer
 
+       peer := new(Peer)
        peer.mutex.Lock()
+       defer peer.mutex.Unlock()
        peer.device = device
-       peer.keyPairs.Init()
        peer.mac.Init(pk)
        peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
        peer.queue.nonce = make(chan []byte, QueueOutboundSize)
+       peer.timer.sendKeepalive = StoppedTimer()
 
-       // map public key
+       // assign id for debugging
 
        device.mutex.Lock()
+       peer.id = device.idCounter
+       device.idCounter += 1
+
+       // map public key
+
        _, ok := device.peers[pk]
        if ok {
                panic(errors.New("bug: adding existing peer"))
        }
-       device.peers[pk] = &peer
+       device.peers[pk] = peer
        device.mutex.Unlock()
 
        // precompute DH
@@ -67,22 +75,24 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        handshake.remoteStatic = pk
        handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
        handshake.mutex.Unlock()
-       peer.mutex.Unlock()
 
-       // start workers
+       // prepare signaling
+
+       peer.signal.stop = make(chan struct{})
+       peer.signal.newKeyPair = make(chan struct{}, 1)
+       peer.signal.handshakeBegin = make(chan struct{}, 1)
+       peer.signal.handshakeCompleted = make(chan struct{}, 1)
+       peer.signal.flushNonceQueue = make(chan struct{}, 1)
 
-       peer.signal.stopSending = make(chan bool, 1)
-       peer.signal.stopInitiator = make(chan bool, 1)
-       peer.signal.newHandshake = make(chan bool, 1)
-       peer.signal.flushNonceQueue = make(chan bool, 1)
+       // outbound pipeline
 
        go peer.RoutineNonce()
        go peer.RoutineHandshakeInitiator()
+       go peer.RoutineSequentialSender()
 
-       return &peer
+       return peer
 }
 
 func (peer *Peer) Close() {
-       peer.signal.stopSending <- true
-       peer.signal.stopInitiator <- true
+       close(peer.signal.stop)
 }
index ab75750f13669c63a35d80408ced5dd0151b8025..d4f9342ab9e58a74bddf95d6b4a44dc0d0ab65dc 100644 (file)
@@ -5,6 +5,8 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "net"
        "sync"
+       "sync/atomic"
+       "time"
 )
 
 /* Handles outbound flow
@@ -29,6 +31,7 @@ type QueueOutboundElement struct {
        packet  []byte
        nonce   uint64
        keyPair *KeyPair
+       peer    *Peer
 }
 
 func (peer *Peer) FlushNonceQueue() {
@@ -46,6 +49,7 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
        for {
                select {
                case peer.queue.outbound <- elem:
+                       return
                default:
                        select {
                        case <-peer.queue.outbound:
@@ -61,11 +65,15 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
  * Obs. Single instance per TUN device
  */
 func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+       if tun.MTU() == 0 {
+               // Dummy
+               return
+       }
+
        device.log.Debug.Println("Routine, TUN Reader: started")
        for {
                // read packet
 
-               device.log.Debug.Println("Read")
                packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
                size, err := tun.Read(packet)
                if err != nil {
@@ -94,13 +102,16 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 
                default:
                        device.log.Debug.Println("Receieved packet with unknown IP version")
-                       return
                }
 
                if peer == nil {
                        device.log.Debug.Println("No peer configured for IP")
                        continue
                }
+               if peer.endpoint == nil {
+                       device.log.Debug.Println("No known endpoint for peer", peer.id)
+                       continue
+               }
 
                // insert into nonce/pre-handshake queue
 
@@ -131,69 +142,95 @@ func (peer *Peer) RoutineNonce() {
        var packet []byte
        var keyPair *KeyPair
 
-       for {
+       device := peer.device
+       logger := device.log.Debug
 
-               // wait for packet
+       logger.Println("Routine, nonce worker, started for peer", peer.id)
 
-               if packet == nil {
-                       select {
-                       case packet = <-peer.queue.nonce:
-                       case <-peer.signal.stopSending:
-                               close(peer.queue.outbound)
-                               return
+       func() {
+
+               for {
+               NextPacket:
+
+                       // wait for packet
+
+                       if packet == nil {
+                               select {
+                               case packet = <-peer.queue.nonce:
+                               case <-peer.signal.stop:
+                                       return
+                               }
                        }
-               }
 
-               // wait for key pair
+                       // wait for key pair
+
+                       for {
+                               select {
+                               case <-peer.signal.newKeyPair:
+                               default:
+                               }
 
-               for keyPair == nil {
-                       peer.signal.newHandshake <- true
-                       select {
-                       case <-peer.keyPairs.newKeyPair:
                                keyPair = peer.keyPairs.Current()
-                               continue
-                       case <-peer.signal.flushNonceQueue:
-                               peer.FlushNonceQueue()
-                               packet = nil
-                               continue
-                       case <-peer.signal.stopSending:
-                               close(peer.queue.outbound)
-                               return
-                       }
-               }
+                               if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
+                                       if time.Now().Sub(keyPair.created) < RejectAfterTime {
+                                               break
+                                       }
+                               }
 
-               // process current packet
+                               sendSignal(peer.signal.handshakeBegin)
+                               logger.Println("Waiting for key-pair, peer", peer.id)
 
-               if packet != nil {
+                               select {
+                               case <-peer.signal.newKeyPair:
+                                       logger.Println("Key-pair negotiated for peer", peer.id)
+                                       goto NextPacket
+
+                               case <-peer.signal.flushNonceQueue:
+                                       logger.Println("Clearing queue for peer", peer.id)
+                                       peer.FlushNonceQueue()
+                                       packet = nil
+                                       goto NextPacket
+
+                               case <-peer.signal.stop:
+                                       return
+                               }
+                       }
 
-                       // create work element
+                       // process current packet
 
-                       work := new(QueueOutboundElement) // TODO: profile, maybe use pool
-                       work.keyPair = keyPair
-                       work.packet = packet
-                       work.nonce = keyPair.sendNonce
-                       work.mutex.Lock()
+                       if packet != nil {
 
-                       packet = nil
-                       keyPair.sendNonce += 1
+                               // create work element
 
-                       // drop packets until there is space
+                               work := new(QueueOutboundElement) // TODO: profile, maybe use pool
+                               work.keyPair = keyPair
+                               work.packet = packet
+                               work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1)
+                               work.peer = peer
+                               work.mutex.Lock()
 
-                       func() {
-                               for {
-                                       select {
-                                       case peer.device.queue.encryption <- work:
-                                               return
-                                       default:
-                                               drop := <-peer.device.queue.encryption
-                                               drop.packet = nil
-                                               drop.mutex.Unlock()
+                               packet = nil
+
+                               // drop packets until there is space
+
+                               func() {
+                                       for {
+                                               select {
+                                               case peer.device.queue.encryption <- work:
+                                                       return
+                                               default:
+                                                       drop := <-peer.device.queue.encryption
+                                                       drop.packet = nil
+                                                       drop.mutex.Unlock()
+                                               }
                                        }
-                               }
-                       }()
-                       peer.queue.outbound <- work
+                               }()
+                               peer.queue.outbound <- work
+                       }
                }
-       }
+       }()
+
+       logger.Println("Routine, nonce worker, stopped for peer", peer.id)
 }
 
 /* Encrypts the elements in the queue
@@ -227,6 +264,10 @@ func (device *Device) RoutineEncryption() {
                        nil,
                )
                work.mutex.Unlock()
+
+               // initiate new handshake
+
+               work.peer.KeepKeyFreshSending()
        }
 }
 
@@ -235,21 +276,54 @@ func (device *Device) RoutineEncryption() {
  * Obs. Single instance per peer.
  * The routine terminates then the outbound queue is closed.
  */
-func (peer *Peer) RoutineSequential() {
-       for work := range peer.queue.outbound {
-               work.mutex.Lock()
-               func() {
-                       peer.mutex.RLock()
-                       defer peer.mutex.RUnlock()
-                       if work.packet == nil {
-                               return
-                       }
-                       if peer.endpoint == nil {
-                               return
-                       }
-                       peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
-                       peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
-               }()
-               work.mutex.Unlock()
+func (peer *Peer) RoutineSequentialSender() {
+       logger := peer.device.log.Debug
+       logger.Println("Routine, sequential sender, started for peer", peer.id)
+
+       device := peer.device
+
+       for {
+               select {
+               case <-peer.signal.stop:
+                       logger.Println("Routine, sequential sender, stopped for peer", peer.id)
+                       return
+               case work := <-peer.queue.outbound:
+                       work.mutex.Lock()
+                       func() {
+                               if work.packet == nil {
+                                       return
+                               }
+
+                               peer.mutex.RLock()
+                               defer peer.mutex.RUnlock()
+
+                               if peer.endpoint == nil {
+                                       logger.Println("No endpoint for peer:", peer.id)
+                                       return
+                               }
+
+                               device.net.mutex.RLock()
+                               defer device.net.mutex.RUnlock()
+
+                               if device.net.conn == nil {
+                                       logger.Println("No source for device")
+                                       return
+                               }
+
+                               logger.Println("Sending packet for peer", peer.id, work.packet)
+
+                               _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
+                               logger.Println("SEND:", peer.endpoint, err)
+                               atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
+
+                               // shift keep-alive timer
+
+                               if peer.persistentKeepaliveInterval != 0 {
+                                       interval := time.Duration(peer.persistentKeepaliveInterval) * time.Second
+                                       peer.timer.sendKeepalive.Reset(interval)
+                               }
+                       }()
+                       work.mutex.Unlock()
+               }
        }
 }
index cbbcb70812d3de606a83bfb8e4d1508dede7267f..db13fb090670e64b41a4e2d5b8403bf71451b8f1 100644 (file)
@@ -74,5 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) {
        return &NativeTun{
                fd:   fd,
                name: newName,
+               mtu:  1500, // TODO: FIX
        }, nil
 }