]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Work on UAPI
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Wed, 28 Jun 2017 21:45:45 +0000 (23:45 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Wed, 28 Jun 2017 21:45:45 +0000 (23:45 +0200)
Cross-platform API (get operation)
Handshake initiation creation process
Outbound packet flow
Fixes from code-review

18 files changed:
src/Makefile [new file with mode: 0644]
src/config.go
src/constants.go
src/device.go
src/handshake.go [new file with mode: 0644]
src/helper_test.go [new file with mode: 0644]
src/ip.go
src/macs_test.go
src/main.go
src/noise_protocol.go
src/noise_test.go
src/noise_types.go
src/peer.go
src/routing.go
src/send.go
src/trie.go
src/tun.go
src/tun_linux.go

diff --git a/src/Makefile b/src/Makefile
new file mode 100644 (file)
index 0000000..4ef8199
--- /dev/null
@@ -0,0 +1,9 @@
+BINARY=wireguard-go
+
+build:
+       go build -o ${BINARY}
+
+clean:
+       if [ -f ${BINARY} ]; then rm ${BINARY}; fi
+
+.PHONY: clean
index cb7e9efe8a09cddd59f5c2a2619802b1536aa041..3b91d00e6e1cf79319913b0b7cba08dbacc6efa7 100644 (file)
@@ -11,7 +11,7 @@ import (
        "time"
 )
 
-/* todo : use real error code
+/* TODO : use real error code
  * Many of which will be the same
  */
 const (
@@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int {
        return s.Code
 }
 
-func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
+func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
 
+       device.mutex.RLock()
+       defer device.mutex.RUnlock()
+
+       // create lines
+
+       lines := make([]string, 0, 100)
+       send := func(line string) {
+               lines = append(lines, line)
+       }
+
+       if !device.privateKey.IsZero() {
+               send("private_key=" + device.privateKey.ToHex())
+       }
+
+       if device.address != nil {
+               send(fmt.Sprintf("listen_port=%d", device.address.Port))
+       }
+
+       for _, peer := range device.peers {
+               func() {
+                       peer.mutex.RLock()
+                       defer peer.mutex.RUnlock()
+                       send("public_key=" + peer.handshake.remoteStatic.ToHex())
+                       send("preshared_key=" + peer.handshake.presharedKey.ToHex())
+                       if peer.endpoint != nil {
+                               send("endpoint=" + peer.endpoint.String())
+                       }
+                       send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
+                       send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
+                       send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+                       for _, ip := range device.routingTable.AllowedIPs(peer) {
+                               send("allowed_ip=" + ip.String())
+                       }
+               }()
+       }
+
+       // send lines
+
+       for _, line := range lines {
+               device.log.Debug.Println("config:", line)
+               _, err := socket.WriteString(line + "\n")
+               if err != nil {
+                       return err
+               }
+       }
+
+       return nil
 }
 
 func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
@@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        return nil
 }
 
-func ipcListen(dev *Device, socket io.ReadWriter) error {
+func ipcListen(device *Device, socket io.ReadWriter) error {
 
        buffered := func(s io.ReadWriter) *bufio.ReadWriter {
                reader := bufio.NewReader(s)
@@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
                return bufio.NewReadWriter(reader, writer)
        }(socket)
 
+       defer buffered.Flush()
+
        for {
                op, err := buffered.ReadString('\n')
                if err != nil {
@@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
                switch op {
 
                case "set=1\n":
-                       err := ipcSetOperation(dev, buffered)
+                       err := ipcSetOperation(device, buffered)
                        if err != nil {
-                               fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
+                               fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
                                return err
                        } else {
-                               fmt.Fprintf(buffered, "errno=0\n")
+                               fmt.Fprintf(buffered, "errno=0\n\n")
                        }
                        buffered.Flush()
 
                case "get=1\n":
+                       err := ipcGetOperation(device, buffered)
+                       if err != nil {
+                               fmt.Fprintf(buffered, "errno=1\n\n") // fix
+                               return err
+                       } else {
+                               fmt.Fprintf(buffered, "errno=0\n\n")
+                       }
+                       buffered.Flush()
 
+               case "\n":
                default:
                        return errors.New("handle this please")
                }
index dc9537946a57103b924dc8d8164e42909192cdf1..e8cdd63e4645d7f83e0a74ff9f7a1ded5fd74930 100644 (file)
@@ -5,12 +5,17 @@ import (
 )
 
 const (
-       RekeyAfterMessage  = (1 << 64) - (1 << 16) - 1
-       RekeyAfterTime     = time.Second * 120
-       RekeyAttemptTime   = time.Second * 90
-       RekeyTimeout       = time.Second * 5
-       RejectAfterTime    = time.Second * 180
-       RejectAfterMessage = (1 << 64) - (1 << 4) - 1
-       KeepaliveTimeout   = time.Second * 10
-       CookieRefreshTime  = time.Second * 2
+       RekeyAfterMessage      = (1 << 64) - (1 << 16) - 1
+       RekeyAfterTime         = time.Second * 120
+       RekeyAttemptTime       = time.Second * 90
+       RekeyTimeout           = time.Second * 5 // TODO: Exponential backoff
+       RejectAfterTime        = time.Second * 180
+       RejectAfterMessage     = (1 << 64) - (1 << 4) - 1
+       KeepaliveTimeout       = time.Second * 10
+       CookieRefreshTime      = time.Second * 2
+       MaxHandshakeAttempTime = time.Second * 90
+)
+
+const (
+       QueueOutboundSize = 1024
 )
index b3484c5fe61e58d1e53041b0f9329e4ce763cdf6..a7a5c7bc12ef7bf3db671cce95c505df90b73242 100644 (file)
@@ -2,23 +2,26 @@ package main
 
 import (
        "net"
+       "runtime"
        "sync"
 )
 
 type Device struct {
-       mtu               int
-       fwMark            uint32
-       address           *net.UDPAddr // UDP source address
-       conn              *net.UDPConn // UDP "connection"
-       mutex             sync.RWMutex
-       privateKey        NoisePrivateKey
-       publicKey         NoisePublicKey
-       routingTable      RoutingTable
-       indices           IndexTable
-       log               *Logger
-       queueWorkOutbound chan *OutboundWorkQueueElement
-       peers             map[NoisePublicKey]*Peer
-       mac               MacStateDevice
+       mtu          int
+       fwMark       uint32
+       address      *net.UDPAddr // UDP source address
+       conn         *net.UDPConn // UDP "connection"
+       mutex        sync.RWMutex
+       privateKey   NoisePrivateKey
+       publicKey    NoisePublicKey
+       routingTable RoutingTable
+       indices      IndexTable
+       log          *Logger
+       queue        struct {
+               encryption chan *QueueOutboundElement // parallel work queue
+       }
+       peers map[NoisePublicKey]*Peer
+       mac   MacStateDevice
 }
 
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
@@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
        }
 }
 
-func (device *Device) Init() {
+func NewDevice(tun TUNDevice) *Device {
+       device := new(Device)
+
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
@@ -49,6 +54,14 @@ func (device *Device) Init() {
        device.peers = make(map[NoisePublicKey]*Peer)
        device.indices.Init()
        device.routingTable.Reset()
+
+       // start workers
+
+       for i := 0; i < runtime.NumCPU(); i += 1 {
+               go device.RoutineEncryption()
+       }
+       go device.RoutineReadFromTUN(tun)
+       return device
 }
 
 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
diff --git a/src/handshake.go b/src/handshake.go
new file mode 100644 (file)
index 0000000..238c339
--- /dev/null
@@ -0,0 +1,172 @@
+package main
+
+import (
+       "bytes"
+       "encoding/binary"
+       "net"
+       "sync/atomic"
+       "time"
+)
+
+/* Sends a keep-alive if no packets queued for peer
+ *
+ * Used by initiator of handshake and with active keep-alive
+ */
+func (peer *Peer) SendKeepAlive() bool {
+       if len(peer.queue.nonce) == 0 {
+               select {
+               case peer.queue.nonce <- []byte{}:
+                       return true
+               default:
+                       return false
+               }
+       }
+       return true
+}
+
+func (peer *Peer) RoutineHandshakeInitiator() {
+       var ongoing bool
+       var begun time.Time
+       var attempts uint
+       var timeout time.Timer
+
+       device := peer.device
+       work := new(QueueOutboundElement)
+       buffer := make([]byte, 0, 1024)
+
+       queueHandshakeInitiation := func() error {
+               work.mutex.Lock()
+               defer work.mutex.Unlock()
+
+               // create initiation
+
+               msg, err := device.CreateMessageInitiation(peer)
+               if err != nil {
+                       return err
+               }
+
+               // create "work" element
+
+               writer := bytes.NewBuffer(buffer[:0])
+               binary.Write(writer, binary.LittleEndian, &msg)
+               work.packet = writer.Bytes()
+               peer.mac.AddMacs(work.packet)
+               peer.InsertOutbound(work)
+               return nil
+       }
+
+       for {
+               select {
+               case <-peer.signal.stopInitiator:
+                       return
+
+               case <-peer.signal.newHandshake:
+                       if ongoing {
+                               continue
+                       }
+
+                       // create handshake
+
+                       err := queueHandshakeInitiation()
+                       if err != nil {
+                               device.log.Error.Println("Failed to create initiation message:", err)
+                       }
+
+                       // log when we began
+
+                       begun = time.Now()
+                       ongoing = true
+                       attempts = 0
+                       timeout.Reset(RekeyTimeout)
+
+               case <-peer.timer.sendKeepalive.C:
+
+                       // active keep-alives
+
+                       peer.SendKeepAlive()
+
+               case <-peer.timer.handshakeTimeout.C:
+
+                       // check if we can stop trying
+
+                       if time.Now().Sub(begun) > MaxHandshakeAttempTime {
+                               peer.signal.flushNonceQueue <- true
+                               peer.timer.sendKeepalive.Stop()
+                               ongoing = false
+                               continue
+                       }
+
+                       // otherwise, try again (exponental backoff)
+
+                       attempts += 1
+                       err := queueHandshakeInitiation()
+                       if err != nil {
+                               device.log.Error.Println("Failed to create initiation message:", err)
+                       }
+                       peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
+               }
+       }
+}
+
+/* Handles packets related to handshake
+ *
+ *
+ */
+func (device *Device) HandshakeWorker(queue chan struct {
+       msg     []byte
+       msgType uint32
+       addr    *net.UDPAddr
+}) {
+       for {
+               elem := <-queue
+
+               switch elem.msgType {
+               case MessageInitiationType:
+                       if len(elem.msg) != MessageInitiationSize {
+                               continue
+                       }
+
+                       // check for cookie
+
+                       var msg MessageInitiation
+
+                       binary.Read(nil, binary.LittleEndian, &msg)
+
+               case MessageResponseType:
+                       if len(elem.msg) != MessageResponseSize {
+                               continue
+                       }
+
+                       // check for cookie
+
+               case MessageCookieReplyType:
+
+               case MessageTransportType:
+               }
+
+       }
+}
+
+func (device *Device) KeepKeyFresh(peer *Peer) {
+
+       send := func() bool {
+               peer.keyPairs.mutex.RLock()
+               defer peer.keyPairs.mutex.RUnlock()
+
+               kp := peer.keyPairs.current
+               if kp == nil {
+                       return false
+               }
+
+               nonce := atomic.LoadUint64(&kp.sendNonce)
+               if nonce > RekeyAfterMessage {
+                       return true
+               }
+
+               return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
+       }()
+
+       if send {
+
+       }
+}
diff --git a/src/helper_test.go b/src/helper_test.go
new file mode 100644 (file)
index 0000000..3a5c331
--- /dev/null
@@ -0,0 +1,64 @@
+package main
+
+import (
+       "bytes"
+       "testing"
+)
+
+/* Helpers for writing unit tests
+ */
+
+type DummyTUN struct {
+       name    string
+       mtu     uint
+       packets chan []byte
+}
+
+func (tun *DummyTUN) Name() string {
+       return tun.name
+}
+
+func (tun *DummyTUN) MTU() uint {
+       return tun.mtu
+}
+
+func (tun *DummyTUN) Write(d []byte) (int, error) {
+       tun.packets <- d
+       return len(d), nil
+}
+
+func (tun *DummyTUN) Read(d []byte) (int, error) {
+       t := <-tun.packets
+       copy(d, t)
+       return len(t), nil
+}
+
+func CreateDummyTUN(name string) (TUNDevice, error) {
+       var dummy DummyTUN
+       dummy.mtu = 1024
+       dummy.packets = make(chan []byte, 100)
+       return &dummy, nil
+}
+
+func assertNil(t *testing.T, err error) {
+       if err != nil {
+               t.Fatal(err)
+       }
+}
+
+func assertEqual(t *testing.T, a []byte, b []byte) {
+       if bytes.Compare(a, b) != 0 {
+               t.Fatal(a, "!=", b)
+       }
+}
+
+func randDevice(t *testing.T) *Device {
+       sk, err := newPrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+       tun, _ := CreateDummyTUN("dummy")
+       device := NewDevice(tun)
+       device.SetPrivateKey(sk)
+       return device
+}
index 3137891283b4034aade7e15d85b84671180d9f71..a9685adba450dc5d2b3a39ae4a79bc2399cc9359 100644 (file)
--- a/src/ip.go
+++ b/src/ip.go
@@ -5,9 +5,10 @@ import (
 )
 
 const (
-       IPv4version   = 4
-       IPv4offsetSrc = 12
-       IPv4offsetDst = IPv4offsetSrc + net.IPv4len
+       IPv4version    = 4
+       IPv4offsetSrc  = 12
+       IPv4offsetDst  = IPv4offsetSrc + net.IPv4len
+       IPv4headerSize = 20
 )
 
 const (
index a67ccfba7ee4df2434c31cad4bf236440df78e6e..fcb64ea8885430f0ad9208163642e13c61ac1699 100644 (file)
@@ -8,8 +8,8 @@ import (
 )
 
 func TestMAC1(t *testing.T) {
-       dev1 := newDevice(t)
-       dev2 := newDevice(t)
+       dev1 := randDevice(t)
+       dev2 := randDevice(t)
 
        peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
        peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@@ -34,12 +34,10 @@ func TestMACs(t *testing.T) {
                msg []byte,
                receiver uint32,
        ) bool {
-               var device1 Device
-               device1.Init()
+               device1 := randDevice(t)
                device1.SetPrivateKey(sk1)
 
-               var device2 Device
-               device2.Init()
+               device2 := randDevice(t)
                device2.SetPrivateKey(sk2)
 
                peer1 := device2.NewPeer(device1.privateKey.publicKey())
index b6f6deb92bb698d776aecd7874d35124ae771bc9..7c589721ee115ec9c816bd91a8c8254504d4fe26 100644 (file)
@@ -1,36 +1,30 @@
 package main
 
 import (
-       "fmt"
+       "log"
+       "net"
 )
 
+/*
+ *
+ * TODO: Fix logging
+ */
+
 func main() {
-       fd, err := CreateTUN("test0")
-       fmt.Println(fd, err)
+       // Open TUN device
 
-       queue := make(chan []byte, 1000)
+       // TODO: Fix capabilities
 
-       // var device Device
+       tun, err := CreateTUN("test0")
+       log.Println(tun, err)
+       if err != nil {
+               return
+       }
 
-       // go OutgoingRoutingWorker(&device, queue)
+       device := NewDevice(tun)
 
-       for {
-               tmp := make([]byte, 1<<16)
-               n, err := fd.Read(tmp)
-               if err != nil {
-                       break
-               }
-               queue <- tmp[:n]
-       }
-}
+       // Start configuration lister
 
-/*
-import (
-       "fmt"
-       "log"
-       "net"
-)
-func main() {
        l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
        if err != nil {
                log.Fatal("listen error:", err)
@@ -41,12 +35,9 @@ func main() {
                if err != nil {
                        log.Fatal("accept error:", err)
                }
-
-               var dev Device
                go func(conn net.Conn) {
-                       err := ipcListen(&dev, conn)
-                       fmt.Println(err)
+                       err := ipcListen(device, conn)
+                       log.Println(err)
                }(fd)
        }
 }
-*/
index e237dbe61873816ecabccf18dfaf6a07341bd358..46ceeda5ab78dc8c0de76f398be5786415b4958f 100644 (file)
@@ -77,7 +77,7 @@ type MessageCookieReply struct {
 
 type Handshake struct {
        state                   int
-       mutex                   sync.Mutex
+       mutex                   sync.RWMutex
        hash                    [blake2s.Size]byte       // hash value
        chainKey                [blake2s.Size]byte       // chain key
        presharedKey            NoiseSymmetricKey        // psk
@@ -205,49 +205,64 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        }
        hash = mixHash(hash, msg.Static[:])
 
-       // find peer
+       // lookup peer
 
        peer := device.LookupPeer(peerPK)
        if peer == nil {
                return nil
        }
        handshake := &peer.handshake
-       handshake.mutex.Lock()
-       defer handshake.mutex.Unlock()
 
-       // decrypt timestamp
+       // verify identity
 
        var timestamp TAI64N
-       func() {
-               var key [chacha20poly1305.KeySize]byte
-               chainKey, key = KDF2(
-                       chainKey[:],
-                       handshake.precomputedStaticStatic[:],
-               )
-               aead, _ := chacha20poly1305.New(key[:])
-               _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
-       }()
-       if err != nil {
-               return nil
-       }
-       hash = mixHash(hash, msg.Timestamp[:])
+       ok := func() bool {
+
+               // read lock handshake
+
+               handshake.mutex.RLock()
+               defer handshake.mutex.RUnlock()
+
+               // decrypt timestamp
+
+               func() {
+                       var key [chacha20poly1305.KeySize]byte
+                       chainKey, key = KDF2(
+                               chainKey[:],
+                               handshake.precomputedStaticStatic[:],
+                       )
+                       aead, _ := chacha20poly1305.New(key[:])
+                       _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
+               }()
+               if err != nil {
+                       return false
+               }
+               hash = mixHash(hash, msg.Timestamp[:])
+
+               // TODO: check for flood attack
+
+               // check for replay attack
 
-       // check for replay attack
+               return timestamp.After(handshake.lastTimestamp)
+       }()
 
-       if !timestamp.After(handshake.lastTimestamp) {
+       if !ok {
                return nil
        }
 
-       // TODO: check for flood attack
-
        // update handshake state
 
+       handshake.mutex.Lock()
+
        handshake.hash = hash
        handshake.chainKey = chainKey
        handshake.remoteIndex = msg.Sender
        handshake.remoteEphemeral = msg.Ephemeral
        handshake.lastTimestamp = timestamp
        handshake.state = HandshakeInitiationConsumed
+
+       handshake.mutex.Unlock()
+
        return peer
 }
 
@@ -320,47 +335,67 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                return nil
        }
 
-       handshake.mutex.Lock()
-       defer handshake.mutex.Unlock()
-       if handshake.state != HandshakeInitiationCreated {
-               return nil
-       }
+       var (
+               hash     [blake2s.Size]byte
+               chainKey [blake2s.Size]byte
+       )
 
-       // finish 3-way DH
+       ok := func() bool {
 
-       hash := mixHash(handshake.hash, msg.Ephemeral[:])
-       chainKey := handshake.chainKey
+               // read lock handshake
 
-       func() {
-               ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
-               chainKey = mixKey(chainKey, ss[:])
-               ss = device.privateKey.sharedSecret(msg.Ephemeral)
-               chainKey = mixKey(chainKey, ss[:])
-       }()
+               handshake.mutex.RLock()
+               defer handshake.mutex.RUnlock()
 
-       // add preshared key (psk)
+               if handshake.state != HandshakeInitiationCreated {
+                       return false
+               }
 
-       var tau [blake2s.Size]byte
-       var key [chacha20poly1305.KeySize]byte
-       chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
-       hash = mixHash(hash, tau[:])
+               // finish 3-way DH
 
-       // authenticate
+               hash = mixHash(handshake.hash, msg.Ephemeral[:])
+               chainKey = handshake.chainKey
 
-       aead, _ := chacha20poly1305.New(key[:])
-       _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
-       if err != nil {
+               func() {
+                       ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+                       chainKey = mixKey(chainKey, ss[:])
+                       ss = device.privateKey.sharedSecret(msg.Ephemeral)
+                       chainKey = mixKey(chainKey, ss[:])
+               }()
+
+               // add preshared key (psk)
+
+               var tau [blake2s.Size]byte
+               var key [chacha20poly1305.KeySize]byte
+               chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
+               hash = mixHash(hash, tau[:])
+
+               // authenticate
+
+               aead, _ := chacha20poly1305.New(key[:])
+               _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+               if err != nil {
+                       return false
+               }
+               hash = mixHash(hash, msg.Empty[:])
+               return true
+       }()
+
+       if !ok {
                return nil
        }
-       hash = mixHash(hash, msg.Empty[:])
 
        // update handshake state
 
+       handshake.mutex.Lock()
+
        handshake.hash = hash
        handshake.chainKey = chainKey
        handshake.remoteIndex = msg.Sender
        handshake.state = HandshakeResponseConsumed
 
+       handshake.mutex.Unlock()
+
        return lookup.peer
 }
 
index dab603b6141c925a688ddd15c0c52fb377a3ae29..02f6bf38f65addc7d57d811899876c7de09c6d79 100644 (file)
@@ -6,29 +6,6 @@ import (
        "testing"
 )
 
-func assertNil(t *testing.T, err error) {
-       if err != nil {
-               t.Fatal(err)
-       }
-}
-
-func assertEqual(t *testing.T, a []byte, b []byte) {
-       if bytes.Compare(a, b) != 0 {
-               t.Fatal(a, "!=", b)
-       }
-}
-
-func newDevice(t *testing.T) *Device {
-       var device Device
-       sk, err := newPrivateKey()
-       if err != nil {
-               t.Fatal(err)
-       }
-       device.Init()
-       device.SetPrivateKey(sk)
-       return &device
-}
-
 func TestCurveWrappers(t *testing.T) {
        sk1, err := newPrivateKey()
        assertNil(t, err)
@@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) {
 
 func TestNoiseHandshake(t *testing.T) {
 
-       dev1 := newDevice(t)
-       dev2 := newDevice(t)
+       dev1 := randDevice(t)
+       dev2 := randDevice(t)
 
        peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
        peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
index 5508f9a526f138659eefe3f27baff9283cb81921..5ebc130c94b9d3267073949f25d6e3a1bbb452af 100644 (file)
@@ -3,18 +3,18 @@ package main
 import (
        "encoding/hex"
        "errors"
+       "golang.org/x/crypto/chacha20poly1305"
 )
 
 const (
-       NoisePublicKeySize    = 32
-       NoisePrivateKeySize   = 32
-       NoiseSymmetricKeySize = 32
+       NoisePublicKeySize  = 32
+       NoisePrivateKeySize = 32
 )
 
 type (
        NoisePublicKey    [NoisePublicKeySize]byte
        NoisePrivateKey   [NoisePrivateKeySize]byte
-       NoiseSymmetricKey [NoiseSymmetricKeySize]byte
+       NoiseSymmetricKey [chacha20poly1305.KeySize]byte
        NoiseNonce        uint64 // padded to 12-bytes
 )
 
@@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error {
        return nil
 }
 
+func (key NoisePrivateKey) IsZero() bool {
+       for _, b := range key[:] {
+               if b != 0 {
+                       return false
+               }
+       }
+       return true
+}
+
 func (key *NoisePrivateKey) FromHex(src string) error {
        return loadExactHex(key[:], src)
 }
index e192b12d19720388086e7948fc6aa44447523ffb..21cad9d81ea271ad6ab31657417885069da246c5 100644 (file)
@@ -7,9 +7,7 @@ import (
        "time"
 )
 
-const (
-       OutboundQueueSize = 64
-)
+const ()
 
 type Peer struct {
        mutex                       sync.RWMutex
@@ -18,10 +16,26 @@ type Peer struct {
        keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
-       queueInbound                chan []byte
-       queueOutbound               chan *OutboundWorkQueueElement
-       queueOutboundRouting        chan []byte
-       mac                         MacStatePeer
+       tx_bytes                    uint64
+       rx_bytes                    uint64
+       time                        struct {
+               lastSend time.Time // last send message
+       }
+       signal struct {
+               newHandshake    chan bool
+               flushNonceQueue chan bool // empty queued packets
+               stopSending     chan bool // stop sending pipeline
+               stopInitiator   chan bool // stop initiator timer
+       }
+       timer struct {
+               sendKeepalive    time.Timer
+               handshakeTimeout time.Timer
+       }
+       queue struct {
+               nonce    chan []byte                // nonce / pre-handshake queue
+               outbound chan *QueueOutboundElement // sequential ordering of work
+       }
+       mac MacStatePeer
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
@@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        peer.device = device
        peer.keyPairs.Init()
        peer.mac.Init(pk)
-       peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
+       peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+       peer.queue.nonce = make(chan []byte, QueueOutboundSize)
 
        // map public key
 
@@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        handshake.mutex.Unlock()
        peer.mutex.Unlock()
 
+       // start workers
+
+       peer.signal.stopSending = make(chan bool, 1)
+       peer.signal.stopInitiator = make(chan bool, 1)
+       peer.signal.newHandshake = make(chan bool, 1)
+       peer.signal.flushNonceQueue = make(chan bool, 1)
+
+       go peer.RoutineNonce()
+       go peer.RoutineHandshakeInitiator()
+
        return &peer
 }
+
+func (peer *Peer) Close() {
+       peer.signal.stopSending <- true
+       peer.signal.stopInitiator <- true
+}
index 4189c2582d0b8bedd41c0c897f527efb9015087d..6a5e1f36f0c7ca1ad97314687223337ff0acb66f 100644 (file)
@@ -12,9 +12,20 @@ type RoutingTable struct {
        mutex sync.RWMutex
 }
 
+func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
+       table.mutex.RLock()
+       defer table.mutex.RUnlock()
+
+       allowed := make([]net.IPNet, 10)
+       table.IPv4.AllowedIPs(peer, allowed)
+       table.IPv6.AllowedIPs(peer, allowed)
+       return allowed
+}
+
 func (table *RoutingTable) Reset() {
        table.mutex.Lock()
        defer table.mutex.Unlock()
+
        table.IPv4 = nil
        table.IPv6 = nil
 }
@@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() {
 func (table *RoutingTable) RemovePeer(peer *Peer) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
+
        table.IPv4 = table.IPv4.RemovePeer(peer)
        table.IPv6 = table.IPv6.RemovePeer(peer)
 }
index f58d311ac28650779b0b96fd33fe34ed4c7f085e..4ff75db94655f74389df5cfc3d588bc1b8c1858b 100644 (file)
@@ -5,107 +5,159 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "net"
        "sync"
-       "time"
 )
 
 /* Handles outbound flow
  *
  * 1. TUN queue
- * 2. Routing
- * 3. Per peer queuing
- * 4. (work queuing)
+ * 2. Routing (sequential)
+ * 3. Nonce assignment (sequential)
+ * 4. Encryption (parallel)
+ * 5. Transmission (sequential)
  *
+ * The order of packets (per peer) is maintained.
+ * The functions in this file occure (roughly) in the order packets are processed.
  */
 
-type OutboundWorkQueueElement struct {
-       wg      sync.WaitGroup
+/* A work unit
+ *
+ * The sequential consumers will attempt to take the lock,
+ * workers release lock when they have completed work on the packet.
+ */
+type QueueOutboundElement struct {
+       mutex   sync.Mutex
        packet  []byte
        nonce   uint64
        keyPair *KeyPair
 }
 
-func (peer *Peer) HandshakeWorker(handshakeQueue []byte) {
-
+func (peer *Peer) FlushNonceQueue() {
+       elems := len(peer.queue.nonce)
+       for i := 0; i < elems; i += 1 {
+               select {
+               case <-peer.queue.nonce:
+               default:
+                       return
+               }
+       }
 }
 
-func (device *Device) SendPacket(packet []byte) {
+func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
+       for {
+               select {
+               case peer.queue.outbound <- elem:
+               default:
+                       select {
+                       case <-peer.queue.outbound:
+                       default:
+                       }
+               }
+       }
+}
 
-       // lookup peer
+/* Reads packets from the TUN and inserts
+ * into nonce queue for peer
+ *
+ * Obs. Single instance per TUN device
+ */
+func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+       for {
+               // read packet
 
-       var peer *Peer
-       switch packet[0] >> 4 {
-       case IPv4version:
-               dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
-               peer = device.routingTable.LookupIPv4(dst)
+               packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
+               size, err := tun.Read(packet)
+               if err != nil {
+                       device.log.Error.Println("Failed to read packet from TUN device:", err)
+                       continue
+               }
+               packet = packet[:size]
+               if len(packet) < IPv4headerSize {
+                       device.log.Error.Println("Packet too short, length:", len(packet))
+                       continue
+               }
 
-       case IPv6version:
-               dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
-               peer = device.routingTable.LookupIPv6(dst)
+               device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
 
-       default:
-               device.log.Debug.Println("receieved packet with unknown IP version")
-               return
-       }
+               // lookup peer
 
-       if peer == nil {
-               return
-       }
+               var peer *Peer
+               switch packet[0] >> 4 {
+               case IPv4version:
+                       dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+                       peer = device.routingTable.LookupIPv4(dst)
 
-       // insert into peer queue
+               case IPv6version:
+                       dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+                       peer = device.routingTable.LookupIPv6(dst)
 
-       for {
-               select {
-               case peer.queueOutboundRouting <- packet:
                default:
+                       device.log.Debug.Println("Receieved packet with unknown IP version")
+                       return
+               }
+
+               if peer == nil {
+                       device.log.Debug.Println("No peer configured for IP")
+                       return
+               }
+
+               // insert into nonce/pre-handshake queue
+
+               for {
                        select {
-                       case <-peer.queueOutboundRouting:
+                       case peer.queue.nonce <- packet:
                        default:
+                               select {
+                               case <-peer.queue.nonce:
+                               default:
+                               }
+                               continue
                        }
-                       continue
+                       break
                }
-               break
        }
 }
 
-/* Go routine
+/* Queues packets when there is no handshake.
+ * Then assigns nonces to packets sequentially
+ * and creates "work" structs for workers
  *
+ * TODO: Avoid dynamic allocation of work queue elements
  *
- * 1. waits for handshake.
- * 2. assigns key pair & nonce
- * 3. inserts to working queue
- *
- * TODO: avoid dynamic allocation of work queue elements
+ * Obs. A single instance per peer
  */
-func (peer *Peer) RoutineOutboundNonceWorker() {
+func (peer *Peer) RoutineNonce() {
        var packet []byte
        var keyPair *KeyPair
-       var flushTimer time.Timer
 
        for {
 
                // wait for packet
 
                if packet == nil {
-                       packet = <-peer.queueOutboundRouting
+                       select {
+                       case packet = <-peer.queue.nonce:
+                       case <-peer.signal.stopSending:
+                               close(peer.queue.outbound)
+                               return
+                       }
                }
 
                // wait for key pair
 
                for keyPair == nil {
-                       flushTimer.Reset(time.Second * 10)
-                       // TODO: Handshake or NOP
+                       peer.signal.newHandshake <- true
                        select {
                        case <-peer.keyPairs.newKeyPair:
                                keyPair = peer.keyPairs.Current()
                                continue
-                       case <-flushTimer.C:
-                               size := len(peer.queueOutboundRouting)
-                               for i := 0; i < size; i += 1 {
-                                       <-peer.queueOutboundRouting
-                               }
+                       case <-peer.signal.flushNonceQueue:
+                               peer.FlushNonceQueue()
                                packet = nil
+                               continue
+                       case <-peer.signal.stopSending:
+                               close(peer.queue.outbound)
+                               return
                        }
-                       break
                }
 
                // process current packet
@@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
 
                        // create work element
 
-                       work := new(OutboundWorkQueueElement)
-                       work.wg.Add(1)
+                       work := new(QueueOutboundElement) // TODO: profile, maybe use pool
                        work.keyPair = keyPair
                        work.packet = packet
                        work.nonce = keyPair.sendNonce
+                       work.mutex.Lock()
 
                        packet = nil
-                       peer.queueOutbound <- work
                        keyPair.sendNonce += 1
 
                        // drop packets until there is space
@@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
                        func() {
                                for {
                                        select {
-                                       case peer.device.queueWorkOutbound <- work:
+                                       case peer.device.queue.encryption <- work:
                                                return
                                        default:
-                                               drop := <-peer.device.queueWorkOutbound
+                                               drop := <-peer.device.queue.encryption
                                                drop.packet = nil
-                                               drop.wg.Done()
+                                               drop.mutex.Unlock()
                                        }
                                }
                        }()
+                       peer.queue.outbound <- work
                }
        }
 }
 
-/* Go routine
- *
- * sequentially reads packets from queue and sends to endpoint
+/* Encrypts the elements in the queue
+ * and marks them for sequential consumption (by releasing the mutex)
  *
+ * Obs. One instance per core
  */
-func (peer *Peer) RoutineSequential() {
-       for work := range peer.queueOutbound {
-               work.wg.Wait()
-               if work.packet == nil {
-                       continue
-               }
-               if peer.endpoint == nil {
-                       continue
-               }
-               peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
-       }
-}
-
-func (device *Device) RoutineEncryptionWorker() {
+func (device *Device) RoutineEncryption() {
        var nonce [chacha20poly1305.NonceSize]byte
-       for work := range device.queueWorkOutbound {
+       for work := range device.queue.encryption {
+
                // pad packet
 
                padding := device.mtu - len(work.packet)
                if padding < 0 {
+                       // drop
                        work.packet = nil
-                       work.wg.Done()
+                       work.mutex.Unlock()
                }
                for n := 0; n < padding; n += 1 {
                        work.packet = append(work.packet, 0)
@@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() {
                        work.packet,
                        nil,
                )
-               work.wg.Done()
+               work.mutex.Unlock()
+       }
+}
+
+/* Sequentially reads packets from queue and sends to endpoint
+ *
+ * Obs. Single instance per peer.
+ * The routine terminates then the outbound queue is closed.
+ */
+func (peer *Peer) RoutineSequential() {
+       for work := range peer.queue.outbound {
+               work.mutex.Lock()
+               func() {
+                       peer.mutex.RLock()
+                       defer peer.mutex.RUnlock()
+                       if work.packet == nil {
+                               return
+                       }
+                       if peer.endpoint == nil {
+                               return
+                       }
+                       peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
+                       peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
+               }()
+               work.mutex.Unlock()
        }
 }
index 746c1b4fbda73c10526453661def7f045acdaa8e..4049167692e3c12a11b82254ed8d24bd4f9b7fb2 100644 (file)
@@ -1,15 +1,20 @@
 package main
 
 import (
+       "errors"
        "net"
 )
 
 /* Binary trie
+ *
+ * The net.IPs used here are not formatted the
+ * same way as those created by the "net" functions.
+ * Here the IPs are slices of either 4 or 16 byte (not always 16)
  *
  * Syncronization done seperatly
  * See: routing.go
  *
- * Todo: Better commenting
+ * TODO: Better commenting
  */
 
 type Trie struct {
@@ -24,7 +29,7 @@ type Trie struct {
 }
 
 /* Finds length of matching prefix
- * Maybe there is a faster way
+ * TODO: Make faster
  *
  * Assumption: len(ip1) == len(ip2)
  */
@@ -189,3 +194,25 @@ func (node *Trie) Count() uint {
        r := node.child[1].Count()
        return l + r
 }
+
+func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
+       if node.peer == p {
+               var mask net.IPNet
+               mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
+               if len(node.bits) == net.IPv4len {
+                       mask.IP = net.IPv4(
+                               node.bits[0],
+                               node.bits[1],
+                               node.bits[2],
+                               node.bits[3],
+                       )
+               } else if len(node.bits) == net.IPv6len {
+                       mask.IP = node.bits
+               } else {
+                       panic(errors.New("bug: unexpected address length"))
+               }
+               results = append(results, mask)
+       }
+       node.child[0].AllowedIPs(p, results)
+       node.child[1].AllowedIPs(p, results)
+}
index 1a8bb822c220dee9ea4e86e32a1e82467f35b536..594754a561df36e35007107e5694a3f681e0ddf1 100644 (file)
@@ -1,6 +1,6 @@
 package main
 
-type TUN interface {
+type TUNDevice interface {
        Read([]byte) (int, error)
        Write([]byte) (int, error)
        Name() string
index d545dfac7f0c6b8f74492aa9bd7128557f6dfc12..cbbcb70812d3de606a83bfb8e4d1508dede7267f 100644 (file)
@@ -9,9 +9,7 @@ import (
        "unsafe"
 )
 
-/* Platform dependent functions for interacting with
- * TUN devices on linux systems
- *
+/* Implementation of the TUN device interface for linux
  */
 
 const CloneDevicePath = "/dev/net/tun"
@@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
        return tun.fd.Read(d)
 }
 
-func CreateTUN(name string) (TUN, error) {
+func CreateTUN(name string) (TUNDevice, error) {
        // Open clone device
        fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
        if err != nil {
@@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) {
        }
 
        // Prepare ifreq struct
-       var ifr [18]byte
+       var ifr [128]byte
        var flags uint16 = IFF_TUN | IFF_NO_PI
        nameBytes := []byte(name)
        if len(nameBytes) >= IFNAMSIZ {