]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
First set of code review patches
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 4 Aug 2017 14:15:53 +0000 (16:15 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 4 Aug 2017 14:15:53 +0000 (16:15 +0200)
15 files changed:
src/config.go
src/constants.go
src/device.go
src/index.go
src/macs.go
src/noise_helpers.go
src/noise_protocol.go
src/noise_types.go
src/receive.go
src/send.go
src/timers.go
src/trie.go
src/tun.go
src/tun_linux.go
src/uapi_linux.go

index 72a604fa49793be07d0c4e99f4d12bcc27765dd7..e2d7f200c1efd3ec44b695acf3f204b9de6406b0 100644 (file)
@@ -61,6 +61,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        send(fmt.Sprintf("persistent_keepalive_interval=%d",
                                atomic.LoadUint64(&peer.persistentKeepaliveInterval),
                        ))
+
                        for _, ip := range device.routingTable.AllowedIPs(peer) {
                                send("allowed_ip=" + ip.String())
                        }
@@ -89,6 +90,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        logDebug := device.log.Debug
 
        var peer *Peer
+
+       deviceConfig := true
+
        for scanner.Scan() {
 
                // parse line
@@ -99,86 +103,110 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                }
                parts := strings.Split(line, "=")
                if len(parts) != 2 {
-                       return &IPCError{Code: ipcErrorNoKeyValue}
+                       return &IPCError{Code: ipcErrorProtocol}
                }
                key := parts[0]
                value := parts[1]
 
-               switch key {
+               /* device configuration */
 
-               /* interface configuration */
+               if deviceConfig {
 
-               case "private_key":
-                       var sk NoisePrivateKey
-                       if value == "" {
-                               device.SetPrivateKey(sk)
-                       } else {
-                               err := sk.FromHex(value)
-                               if err != nil {
-                                       logError.Println("Failed to set private_key:", err)
-                                       return &IPCError{Code: ipcErrorInvalidValue}
+                       switch key {
+                       case "private_key":
+                               var sk NoisePrivateKey
+                               if value == "" {
+                                       device.SetPrivateKey(sk)
+                               } else {
+                                       err := sk.FromHex(value)
+                                       if err != nil {
+                                               logError.Println("Failed to set private_key:", err)
+                                               return &IPCError{Code: ipcErrorInvalid}
+                                       }
+                                       device.SetPrivateKey(sk)
                                }
-                               device.SetPrivateKey(sk)
-                       }
 
-               case "listen_port":
-                       port, err := strconv.ParseUint(value, 10, 16)
-                       if err != nil {
-                               logError.Println("Failed to set listen_port:", err)
-                               return &IPCError{Code: ipcErrorInvalidValue}
-                       }
-                       netc := &device.net
-                       netc.mutex.Lock()
-                       if netc.addr.Port != int(port) {
-                               if netc.conn != nil {
-                                       netc.conn.Close()
+                       case "listen_port":
+                               port, err := strconv.ParseUint(value, 10, 16)
+                               if err != nil {
+                                       logError.Println("Failed to set listen_port:", err)
+                                       return &IPCError{Code: ipcErrorInvalid}
                                }
-                               netc.addr.Port = int(port)
-                               netc.conn, err = net.ListenUDP("udp", netc.addr)
-                       }
-                       netc.mutex.Unlock()
-                       if err != nil {
-                               logError.Println("Failed to create UDP listener:", err)
-                               return &IPCError{Code: ipcErrorInvalidValue}
-                       }
+                               netc := &device.net
+                               netc.mutex.Lock()
+                               if netc.addr.Port != int(port) {
+                                       if netc.conn != nil {
+                                               netc.conn.Close()
+                                       }
+                                       netc.addr.Port = int(port)
+                                       netc.conn, err = net.ListenUDP("udp", netc.addr)
+                               }
+                               netc.mutex.Unlock()
+                               if err != nil {
+                                       logError.Println("Failed to create UDP listener:", err)
+                                       return &IPCError{Code: ipcErrorIO}
+                               }
+                               // TODO: Clear source address of all peers
 
-               case "fwmark":
-                       logError.Println("FWMark not handled yet")
+                       case "fwmark":
+                               logError.Println("FWMark not handled yet")
+                               // TODO: Clear source address of all peers
 
-               case "public_key":
-                       var pubKey NoisePublicKey
-                       err := pubKey.FromHex(value)
-                       if err != nil {
-                               logError.Println("Failed to get peer by public_key:", err)
-                               return &IPCError{Code: ipcErrorInvalidValue}
-                       }
-                       device.mutex.RLock()
-                       peer, _ = device.peers[pubKey]
-                       device.mutex.RUnlock()
-                       if peer == nil {
-                               peer = device.NewPeer(pubKey)
-                       }
+                       case "public_key":
 
-               case "replace_peers":
-                       if value == "true" {
-                               device.RemoveAllPeers()
-                       } else {
-                               logError.Println("Failed to set replace_peers, invalid value:", value)
-                               return &IPCError{Code: ipcErrorInvalidValue}
-                       }
+                               // switch to peer configuration
 
-               default:
+                               deviceConfig = false
 
-                       /* peer configuration */
+                       case "replace_peers":
+                               if value != "true" {
+                                       logError.Println("Failed to set replace_peers, invalid value:", value)
+                                       return &IPCError{Code: ipcErrorInvalid}
+                               }
+                               device.RemoveAllPeers()
 
-                       if peer == nil {
-                               logError.Println("No peer referenced, before peer operation")
-                               return &IPCError{Code: ipcErrorNoPeer}
+                       default:
+                               logError.Println("Invalid UAPI key (device configuration):", key)
+                               return &IPCError{Code: ipcErrorInvalid}
                        }
+               }
+
+               /* peer configuration */
+
+               if !deviceConfig {
 
                        switch key {
 
+                       case "public_key":
+                               var pubKey NoisePublicKey
+                               err := pubKey.FromHex(value)
+                               if err != nil {
+                                       logError.Println("Failed to get peer by public_key:", err)
+                                       return &IPCError{Code: ipcErrorInvalid}
+                               }
+
+                               // check if public key of peer equal to device
+
+                               device.mutex.RLock()
+                               if device.publicKey.Equals(pubKey) {
+                                       device.mutex.RUnlock()
+                                       logError.Println("Public key of peer matches private key of device")
+                                       return &IPCError{Code: ipcErrorInvalid}
+                               }
+
+                               // find peer referenced
+
+                               peer, _ = device.peers[pubKey]
+                               device.mutex.RUnlock()
+                               if peer == nil {
+                                       peer = device.NewPeer(pubKey)
+                               }
+
                        case "remove":
+                               if value != "true" {
+                                       logError.Println("Failed to set remove, invalid value:", value)
+                                       return &IPCError{Code: ipcErrorInvalid}
+                               }
                                device.RemovePeer(peer.handshake.remoteStatic)
                                logDebug.Println("Removing", peer.String())
                                peer = nil
@@ -191,50 +219,67 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }()
                                if err != nil {
                                        logError.Println("Failed to set preshared_key:", err)
-                                       return &IPCError{Code: ipcErrorInvalidValue}
+                                       return &IPCError{Code: ipcErrorInvalid}
                                }
 
                        case "endpoint":
+                               // TODO: Only IP and port
                                addr, err := net.ResolveUDPAddr("udp", value)
                                if err != nil {
                                        logError.Println("Failed to set endpoint:", value)
-                                       return &IPCError{Code: ipcErrorInvalidValue}
+                                       return &IPCError{Code: ipcErrorInvalid}
                                }
                                peer.mutex.Lock()
                                peer.endpoint = addr
                                peer.mutex.Unlock()
 
                        case "persistent_keepalive_interval":
-                               secs, err := strconv.ParseInt(value, 10, 64)
-                               if secs < 0 || err != nil {
+
+                               // update keep-alive interval
+
+                               secs, err := strconv.ParseUint(value, 10, 16)
+                               if err != nil {
                                        logError.Println("Failed to set persistent_keepalive_interval:", err)
-                                       return &IPCError{Code: ipcErrorInvalidValue}
+                                       return &IPCError{Code: ipcErrorInvalid}
                                }
-                               atomic.StoreUint64(
+
+                               old := atomic.SwapUint64(
                                        &peer.persistentKeepaliveInterval,
-                                       uint64(secs),
+                                       secs,
                                )
 
+                               // send immediate keep-alive
+
+                               if old == 0 && secs != 0 {
+                                       up, err := device.tun.IsUp()
+                                       if err != nil {
+                                               logError.Println("Failed to get tun device status:", err)
+                                               return &IPCError{Code: ipcErrorIO}
+                                       }
+                                       if up {
+                                               peer.SendKeepAlive()
+                                       }
+                               }
+
                        case "replace_allowed_ips":
-                               if value == "true" {
-                                       device.routingTable.RemovePeer(peer)
-                               } else {
+                               if value != "true" {
                                        logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
-                                       return &IPCError{Code: ipcErrorInvalidValue}
+                                       return &IPCError{Code: ipcErrorInvalid}
                                }
+                               device.routingTable.RemovePeer(peer)
 
                        case "allowed_ip":
                                _, network, err := net.ParseCIDR(value)
                                if err != nil {
                                        logError.Println("Failed to set allowed_ip:", err)
-                                       return &IPCError{Code: ipcErrorInvalidValue}
+                                       return &IPCError{Code: ipcErrorInvalid}
                                }
                                ones, _ := network.Mask.Size()
                                device.routingTable.Insert(network.IP, uint(ones), peer)
 
                        default:
-                               logError.Println("Invalid UAPI key:", key)
-                               return &IPCError{Code: ipcErrorInvalidKey}
+                               logError.Println("Invalid UAPI key (peer configuration):", key)
+                               return &IPCError{Code: ipcErrorInvalid}
                        }
                }
        }
@@ -244,6 +289,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
 func ipcHandle(device *Device, socket net.Conn) {
 
+       // create buffered read/writer
+
        defer socket.Close()
 
        buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@@ -259,30 +306,30 @@ func ipcHandle(device *Device, socket net.Conn) {
                return
        }
 
-       switch op {
+       // handle operation
 
+       var status *IPCError
+
+       switch op {
        case "set=1\n":
                device.log.Debug.Println("Config, set operation")
-               err := ipcSetOperation(device, buffered)
-               if err != nil {
-                       fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
-               } else {
-                       fmt.Fprintf(buffered, "errno=0\n\n")
-               }
-               return
+               status = ipcSetOperation(device, buffered)
 
        case "get=1\n":
                device.log.Debug.Println("Config, get operation")
-               err := ipcGetOperation(device, buffered)
-               if err != nil {
-                       fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
-               } else {
-                       fmt.Fprintf(buffered, "errno=0\n\n")
-               }
-               return
+               status = ipcGetOperation(device, buffered)
 
        default:
                device.log.Error.Println("Invalid UAPI operation:", op)
+               return
+       }
+
+       // write status
 
+       if status != nil {
+               device.log.Error.Println(status)
+               fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+       } else {
+               fmt.Fprintf(buffered, "errno=0\n\n")
        }
 }
index 09d33d858e1128da69a62a9d5294eb12a0cfb35a..f09ded62678db60fdc22079ce9e879a207d27a04 100644 (file)
@@ -16,6 +16,7 @@ const (
        KeepaliveTimeout        = time.Second * 10
        CookieRefreshTime       = time.Second * 120
        MaxHandshakeAttemptTime = time.Second * 90
+       PaddingMultiple         = 16
 )
 
 const (
@@ -31,5 +32,5 @@ const (
        QueueHandshakeSize     = 1024
        QueueHandshakeBusySize = QueueHandshakeSize / 8
        MinMessageSize         = MessageTransportSize // size of keep-alive
-       MaxMessageSize         = (1 << 16) - 1
+       MaxMessageSize         = ((1 << 16) - 1) + MessageTransportHeaderSize
 )
index 1185d609a3703318d40dd9bcf11e7aa098f69d6e..de96f0b19073bb4aadbca712ae476e6cf3220a52 100644 (file)
@@ -1,6 +1,8 @@
 package main
 
 import (
+       "errors"
+       "fmt"
        "net"
        "runtime"
        "sync"
@@ -10,6 +12,7 @@ import (
 
 type Device struct {
        mtu       int32
+       tun       TUNDevice
        log       *Logger // collection of loggers for levels
        idCounter uint    // for assigning debug ids to peers
        fwMark    uint32
@@ -43,24 +46,46 @@ type Device struct {
        mac         MACStateDevice
 }
 
-func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
+       // check if public key is matching any peer
+
+       publicKey := sk.publicKey()
+       for _, peer := range device.peers {
+               h := &peer.handshake
+               h.mutex.RLock()
+               if h.remoteStatic.Equals(publicKey) {
+                       h.mutex.RUnlock()
+                       return errors.New("Private key matches public key of peer")
+               }
+               h.mutex.RUnlock()
+       }
+
        // update key material
 
        device.privateKey = sk
-       device.publicKey = sk.publicKey()
-       device.mac.Init(device.publicKey)
+       device.publicKey = publicKey
+       device.mac.Init(publicKey)
 
        // do DH precomputations
 
+       isZero := device.privateKey.IsZero()
+
        for _, peer := range device.peers {
                h := &peer.handshake
                h.mutex.Lock()
-               h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+               if isZero {
+                       h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+               } else {
+                       h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+               }
+               fmt.Println(h.precomputedStaticStatic)
                h.mutex.Unlock()
        }
+
+       return nil
 }
 
 func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
@@ -77,6 +102,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
+       device.tun = tun
        device.log = NewLogger(logLevel)
        device.peers = make(map[NoisePublicKey]*Peer)
        device.indices.Init()
@@ -119,22 +145,22 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        }
 
        go device.RoutineBusyMonitor()
-       go device.RoutineMTUUpdater(tun)
-       go device.RoutineWriteToTUN(tun)
-       go device.RoutineReadFromTUN(tun)
+       go device.RoutineMTUUpdater()
+       go device.RoutineWriteToTUN()
+       go device.RoutineReadFromTUN()
        go device.RoutineReceiveIncomming()
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 
        return device
 }
 
-func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
+func (device *Device) RoutineMTUUpdater() {
        logError := device.log.Error
        for ; ; time.Sleep(5 * time.Second) {
 
                // load updated MTU
 
-               mtu, err := tun.MTU()
+               mtu, err := device.tun.MTU()
                if err != nil {
                        logError.Println("Failed to load updated MTU of device:", err)
                        continue
index 44b49744c9c531725d66cbc8cf8fcb2c9999f37b..e518b0f8c2adeb39759f02806c0d1227f24d16ce 100644 (file)
@@ -3,6 +3,7 @@ package main
 import (
        "crypto/rand"
        "sync"
+       "unsafe"
 )
 
 /* Index=0 is reserved for unset indecies
@@ -23,14 +24,7 @@ type IndexTable struct {
 func randUint32() (uint32, error) {
        var buff [4]byte
        _, err := rand.Read(buff[:])
-       id := uint32(buff[0])
-       id <<= 8
-       id |= uint32(buff[1])
-       id <<= 8
-       id |= uint32(buff[2])
-       id <<= 8
-       id |= uint32(buff[3])
-       return id, err
+       return *((*uint32)(unsafe.Pointer(&buff))), err
 }
 
 func (table *IndexTable) Init() {
index 841ef318f23ba4d216279398f8fa1a8ea01d6fc6..beb5f7689c32019b7d11e26487327881d2fb8693 100644 (file)
@@ -3,7 +3,6 @@ package main
 import (
        "crypto/hmac"
        "crypto/rand"
-       "errors"
        "golang.org/x/crypto/blake2s"
        "net"
        "sync"
@@ -15,14 +14,14 @@ type MACStateDevice struct {
        refreshed time.Time
        secret    [blake2s.Size]byte
        keyMAC1   [blake2s.Size]byte
-       keyMAC2   [blake2s.Size]byte
+       keyMAC2   [blake2s.Size]byte // TODO: Change to more descriptive size constant, rename to something.
 }
 
 type MACStatePeer struct {
        mutex     sync.RWMutex
        cookieSet time.Time
        cookie    [blake2s.Size128]byte
-       lastMAC1  [blake2s.Size128]byte
+       lastMAC1  [blake2s.Size128]byte // TODO: Check if set
        keyMAC1   [blake2s.Size]byte
        keyMAC2   [blake2s.Size]byte
 }
@@ -83,7 +82,7 @@ func (state *MACStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool {
                port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
                mac, _ := blake2s.New128(state.secret[:])
                mac.Write(addr.IP)
-               mac.Write(port[:])
+               mac.Write(port[:]) // TODO: Be faster and more platform dependent?
                mac.Sum(cookie[:0])
        }()
 
@@ -130,7 +129,7 @@ func (device *Device) CreateMessageCookieReply(
                port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
                mac, _ := blake2s.New128(state.secret[:])
                mac.Write(addr.IP)
-               mac.Write(port[:])
+               mac.Write(port[:]) // TODO: Do whatever we did above
                mac.Sum(cookie[:0])
        }()
 
@@ -196,6 +195,7 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
        if err != nil {
                return false
        }
+
        state.cookieSet = time.Now()
        state.cookie = cookie
        return true
@@ -229,10 +229,6 @@ func (state *MACStatePeer) Init(pk NoisePublicKey) {
 func (state *MACStatePeer) AddMacs(msg []byte) {
        size := len(msg)
 
-       if size < blake2s.Size128*2 {
-               panic(errors.New("bug: message too short"))
-       }
-
        startMac1 := size - (blake2s.Size128 * 2)
        startMac2 := size - blake2s.Size128
 
@@ -250,6 +246,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
                mac.Sum(mac1[:0])
        }()
        copy(state.lastMAC1[:], mac1)
+       // TODO: Set lastMac flag
 
        // set mac2
 
index 1e622a5bbf2553fbfc3a82c2b3d602eeadd4f201..105f78f4829096c3af5b9e6be4b71e35efef260a 100644 (file)
@@ -47,6 +47,14 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
        return
 }
 
+func isZero(val []byte) bool {
+       var acc byte
+       for _, b := range val {
+               acc |= b
+       }
+       return acc == 0
+}
+
 /* curve25519 wrappers */
 
 func newPrivateKey() (sk NoisePrivateKey, err error) {
index e2ff5736eb968bcc9b34e7ff404ecfc3b01687ce..5c776a81a192ad2677eee53a25e06d149e14bfb2 100644 (file)
@@ -135,6 +135,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
 
+       if isZero(handshake.precomputedStaticStatic[:]) {
+               return nil, errors.New("Static shared secret is zero")
+       }
+
        // create ephemeral key
 
        var err error
@@ -226,7 +230,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        if peer == nil {
                return nil
        }
+
        handshake := &peer.handshake
+       if isZero(handshake.precomputedStaticStatic[:]) {
+               return nil
+       }
 
        // verify identity
 
@@ -472,6 +480,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
        func() {
                kp.mutex.Lock()
                defer kp.mutex.Unlock()
+               // TODO: Adapt kernel behavior noise.c:161
                if isInitiator {
                        if kp.previous != nil {
                                kp.previous.send = nil
index 5ebc130c94b9d3267073949f25d6e3a1bbb452af..1a944dfe908fe51b3cd15ac0df2f630b92953145 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "crypto/subtle"
        "encoding/hex"
        "errors"
        "golang.org/x/crypto/chacha20poly1305"
@@ -31,12 +32,12 @@ func loadExactHex(dst []byte, src string) error {
 }
 
 func (key NoisePrivateKey) IsZero() bool {
-       for _, b := range key[:] {
-               if b != 0 {
-                       return false
-               }
-       }
-       return true
+       var zero NoisePrivateKey
+       return key.Equals(zero)
+}
+
+func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
+       return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
 }
 
 func (key *NoisePrivateKey) FromHex(src string) error {
@@ -55,6 +56,15 @@ func (key NoisePublicKey) ToHex() string {
        return hex.EncodeToString(key[:])
 }
 
+func (key NoisePublicKey) IsZero() bool {
+       var zero NoisePublicKey
+       return key.Equals(zero)
+}
+
+func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
+       return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
+}
+
 func (key *NoiseSymmetricKey) FromHex(src string) error {
        return loadExactHex(key[:], src)
 }
index 700b8940883217d818eb79e1ff9ad46ae4f3a948..fb5c51fa5a5e02c1e53697642e922c8c1c314ace 100644 (file)
@@ -73,6 +73,8 @@ func (device *Device) addToHandshakeQueue(
 }
 
 /* Routine determining the busy state of the interface
+ *
+ * TODO: Under load for some time
  */
 func (device *Device) RoutineBusyMonitor() {
        samples := 0
@@ -131,6 +133,7 @@ func (device *Device) RoutineReceiveIncomming() {
                        buffer = device.GetMessageBuffer()
                }
 
+               // TODO: Take writelock to sleep
                device.net.mutex.RLock()
                conn := device.net.conn
                device.net.mutex.RUnlock()
@@ -139,6 +142,7 @@ func (device *Device) RoutineReceiveIncomming() {
                        continue
                }
 
+               // TODO: Wait for new conn or message
                conn.SetReadDeadline(time.Now().Add(time.Second))
 
                size, raddr, err := conn.ReadFromUDP(buffer[:])
@@ -156,6 +160,8 @@ func (device *Device) RoutineReceiveIncomming() {
 
                        case MessageInitiationType, MessageResponseType:
 
+                               // TODO: Check size early
+
                                // add to handshake queue
 
                                device.addToHandshakeQueue(
@@ -171,6 +177,8 @@ func (device *Device) RoutineReceiveIncomming() {
 
                        case MessageCookieReplyType:
 
+                               // TODO: Queue all the things
+
                                // verify and update peer cookie state
 
                                if len(packet) != MessageCookieReplySize {
@@ -250,7 +258,7 @@ func (device *Device) RoutineDecryption() {
                // check if dropped
 
                if elem.IsDropped() {
-                       elem.mutex.Unlock()
+                       elem.mutex.Unlock() // TODO: Make consistent with send
                        continue
                }
 
@@ -318,6 +326,7 @@ func (device *Device) RoutineHandshake() {
                                        logError.Println("Failed to create cookie reply:", err)
                                        return
                                }
+                               // TODO: Use temp
                                writer := bytes.NewBuffer(elem.packet[:0])
                                binary.Write(writer, binary.LittleEndian, reply)
                                elem.packet = writer.Bytes()
@@ -330,6 +339,8 @@ func (device *Device) RoutineHandshake() {
 
                        // ratelimit
 
+                       // TODO: Only ratelimit when busy
+
                        if !device.ratelimiter.Allow(elem.source.IP) {
                                return
                        }
@@ -364,9 +375,14 @@ func (device *Device) RoutineHandshake() {
                                        )
                                        return
                                }
-                               peer.TimerPacketReceived()
+
+                               // update timers
+
+                               peer.TimerAnyAuthenticatedPacketTraversal()
+                               peer.TimerAnyAuthenticatedPacketReceived()
 
                                // update endpoint
+                               // TODO: Add a race condition \s
 
                                peer.mutex.Lock()
                                peer.endpoint = elem.source
@@ -381,6 +397,7 @@ func (device *Device) RoutineHandshake() {
                                }
 
                                peer.TimerEphemeralKeyCreated()
+                               peer.NewKeyPair()
 
                                logDebug.Println("Creating response message for", peer.String())
 
@@ -392,8 +409,7 @@ func (device *Device) RoutineHandshake() {
                                // send response
 
                                peer.SendBuffer(packet)
-                               peer.TimerPacketSent()
-                               peer.NewKeyPair()
+                               peer.TimerAnyAuthenticatedPacketTraversal()
 
                        case MessageResponseType:
 
@@ -423,8 +439,14 @@ func (device *Device) RoutineHandshake() {
                                        return
                                }
 
-                               peer.TimerPacketReceived()
+                               // update timers
+
+                               peer.TimerAnyAuthenticatedPacketTraversal()
+                               peer.TimerAnyAuthenticatedPacketReceived()
                                peer.TimerHandshakeComplete()
+
+                               // derive key-pair
+
                                peer.NewKeyPair()
                                peer.SendKeepAlive()
 
@@ -467,8 +489,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                return
                        }
 
-                       peer.TimerPacketReceived()
-                       peer.TimerTransportReceived()
+                       peer.TimerAnyAuthenticatedPacketTraversal()
+                       peer.TimerAnyAuthenticatedPacketReceived()
                        peer.KeepKeyFreshReceiving()
 
                        // check if using new key-pair
@@ -504,6 +526,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                                field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
                                length := binary.BigEndian.Uint16(field)
+                               // TODO: check length of packet & NOT TOO SMALL either
                                elem.packet = elem.packet[:length]
 
                                // verify IPv4 source
@@ -525,6 +548,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
                                length := binary.BigEndian.Uint16(field)
                                length += ipv6.HeaderLen
+                               // TODO: check length of packet
                                elem.packet = elem.packet[:length]
 
                                // verify IPv6 source
@@ -542,11 +566,13 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                        atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
                        device.addToInboundQueue(device.queue.inbound, elem)
+
+                       // TODO: move TUN write into per peer routine
                }()
        }
 }
 
-func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
+func (device *Device) RoutineWriteToTUN() {
 
        logError := device.log.Error
        logDebug := device.log.Debug
@@ -557,7 +583,7 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
                case <-device.signal.stop:
                        return
                case elem := <-device.queue.inbound:
-                       _, err := tun.Write(elem.packet)
+                       _, err := device.tun.Write(elem.packet)
                        device.PutMessageBuffer(elem.buffer)
                        if err != nil {
                                logError.Println("Failed to write packet to TUN device:", err)
index 37078b97680b1920d12670dbcf3bad0b90a00058..fc3573264521e0193147628710c89af6da2969a8 100644 (file)
@@ -110,17 +110,19 @@ func addToEncryptionQueue(
 }
 
 func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
+       peer.device.net.mutex.RLock()
+       defer peer.device.net.mutex.RUnlock()
 
        peer.mutex.RLock()
+       defer peer.mutex.RUnlock()
+
        endpoint := peer.endpoint
-       peer.mutex.RUnlock()
+       conn := peer.device.net.conn
+
        if endpoint == nil {
                return 0, ErrorNoEndpoint
        }
 
-       peer.device.net.mutex.RLock()
-       conn := peer.device.net.conn
-       peer.device.net.mutex.RUnlock()
        if conn == nil {
                return 0, ErrorNoConnection
        }
@@ -133,13 +135,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
  *
  * Obs. Single instance per TUN device
  */
-func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+func (device *Device) RoutineReadFromTUN() {
 
-       if tun == nil {
+       if device.tun == nil {
                return
        }
 
-       elem := device.NewOutboundElement()
+       var elem *QueueOutboundElement
 
        logDebug := device.log.Debug
        logError := device.log.Error
@@ -153,32 +155,38 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
                        elem = device.NewOutboundElement()
                }
 
+               // TODO: THIS!
                elem.packet = elem.buffer[MessageTransportHeaderSize:]
-               size, err := tun.Read(elem.packet)
+               size, err := device.tun.Read(elem.packet)
                if err != nil {
-
-                       // stop process
-
                        logError.Println("Failed to read packet from TUN device:", err)
                        device.Close()
                        return
                }
 
-               elem.packet = elem.packet[:size]
-               if len(elem.packet) < ipv4.HeaderLen {
-                       logError.Println("Packet too short, length:", size)
+               if size == 0 {
                        continue
                }
 
+               println(size, err)
+
+               elem.packet = elem.packet[:size]
+
                // lookup peer
 
                var peer *Peer
                switch elem.packet[0] >> 4 {
                case ipv4.Version:
+                       if len(elem.packet) < ipv4.HeaderLen {
+                               continue
+                       }
                        dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
                        peer = device.routingTable.LookupIPv4(dst)
 
                case ipv6.Version:
+                       if len(elem.packet) < ipv6.HeaderLen {
+                               continue
+                       }
                        dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
                        peer = device.routingTable.LookupIPv6(dst)
 
@@ -190,10 +198,15 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
                        continue
                }
 
+               // check if known endpoint
+
+               peer.mutex.RLock()
                if peer.endpoint == nil {
+                       peer.mutex.RUnlock()
                        logDebug.Println("No known endpoint for peer", peer.String())
                        continue
                }
+               peer.mutex.RUnlock()
 
                // insert into nonce/pre-handshake queue
 
@@ -334,8 +347,12 @@ func (device *Device) RoutineEncryption() {
                // pad content to MTU size
 
                mtu := int(atomic.LoadInt32(&device.mtu))
-               for i := len(elem.packet); i < mtu; i++ {
-                       elem.packet = append(elem.packet, 0)
+               pad := len(elem.packet) % PaddingMultiple
+               if pad > 0 {
+                       for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
+                               elem.packet = append(elem.packet, 0)
+                       }
+                       // TODO: How good is this code
                }
 
                // encrypt content (append to header)
@@ -390,7 +407,7 @@ func (peer *Peer) RoutineSequentialSender() {
 
                        // update timers
 
-                       peer.TimerPacketSent()
+                       peer.TimerAnyAuthenticatedPacketTraversal()
                        if len(elem.packet) != MessageKeepaliveSize {
                                peer.TimerDataSent()
                        }
index 5a16e9bfe3280352b84c10618c3ed5ff30db52bb..1be85f07cede46033eabacbfaf582ae6840ec4a2 100644 (file)
@@ -60,10 +60,8 @@ func (peer *Peer) SendKeepAlive() bool {
        return true
 }
 
-/* Authenticated data packet send
- * Always called together with peer.EventPacketSend
- *
- * - Start new handshake timer
+/* Event:
+ * Sent non-empty (authenticated) transport message
  */
 func (peer *Peer) TimerDataSent() {
        timerStop(peer.timer.keepalivePassive)
@@ -75,8 +73,6 @@ func (peer *Peer) TimerDataSent() {
 
 /* Event:
  * Received non-empty (authenticated) transport message
- *
- * - Start passive keep-alive timer
  */
 func (peer *Peer) TimerDataReceived() {
        if peer.timer.pendingKeepalivePassive {
@@ -88,17 +84,16 @@ func (peer *Peer) TimerDataReceived() {
 }
 
 /* Event:
- * Any (authenticated) transport message received
- * (keep-alive or data)
+ * Any (authenticated) packet received
  */
-func (peer *Peer) TimerTransportReceived() {
+func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
        timerStop(peer.timer.newHandshake)
 }
 
 /* Event:
- * Any packet send to the peer.
+ * Any authenticated packet send / received.
  */
-func (peer *Peer) TimerPacketSent() {
+func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
        interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
        if interval > 0 {
                duration := time.Duration(interval) * time.Second
@@ -106,13 +101,6 @@ func (peer *Peer) TimerPacketSent() {
        }
 }
 
-/* Event:
- * Any authenticated packet received from peer
- */
-func (peer *Peer) TimerPacketReceived() {
-       peer.TimerPacketSent()
-}
-
 /* Called after succesfully completing a handshake.
  * i.e. after:
  *
@@ -129,7 +117,9 @@ func (peer *Peer) TimerHandshakeComplete() {
        peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
 }
 
-/* Called whenever an ephemeral key is generated
+/* Event:
+ * An ephemeral key is generated
+ *
  * i.e after:
  *
  * CreateMessageInitiation
@@ -257,7 +247,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 
                select {
                case <-peer.signal.handshakeBegin:
-                       signalSend(peer.signal.handshakeBegin)
                case <-peer.signal.stop:
                        return
                }
@@ -303,7 +292,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                        binary.Write(writer, binary.LittleEndian, msg)
                        packet := writer.Bytes()
                        peer.mac.AddMacs(packet)
-                       peer.TimerPacketSent()
 
                        _, err = peer.SendBuffer(packet)
                        if err != nil {
@@ -314,6 +302,8 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                                continue
                        }
 
+                       peer.TimerAnyAuthenticatedPacketTraversal()
+
                        // set timeout
 
                        timeout := time.NewTimer(RekeyTimeout)
@@ -337,7 +327,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                                continue
 
                        }
-
                }
 
                // allow new signal to be set
index e81b5b62c95409b2bb0c9071dff0f38b32e8d59e..aa96a8adcbc50c69b27d048d377e0d5bee002345 100644 (file)
@@ -32,11 +32,14 @@ type Trie struct {
 /* Finds length of matching prefix
  * TODO: Make faster
  *
- * Assumption: len(ip1) == len(ip2)
+ * Assumption:
+ *       len(ip1) == len(ip2)
+ *       len(ip1) mod 4 = 0
  */
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1 []byte, ip2 []byte) uint {
        var i uint
-       size := uint(len(ip1))
+       size := uint(len(ip1)) / 4
+
        for i = 0; i < size; i++ {
                v := ip1[i] ^ ip2[i]
                if v != 0 {
index f529c54f5ea25a64a9e1368ba39757e3d74da385..d782bd57786bf1c541fb5196e4fe765cde01ac0f 100644 (file)
@@ -9,6 +9,7 @@ const DefaultMTU = 1420
 type TUNDevice interface {
        Read([]byte) (int, error)  // read a packet from the device (without any additional headers)
        Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
+       IsUp() (bool, error)       // is the interface up?
        MTU() (int, error)         // returns the MTU of the device
        Name() string              // returns the current name
 }
index 261d14254d0cb8ec7f2f11d7a1a4780dbc53a379..d0e2f470592db6d177f1ef50cb0e035bb26f14b3 100644 (file)
@@ -7,6 +7,7 @@ import (
        "encoding/binary"
        "errors"
        "golang.org/x/sys/unix"
+       "net"
        "os"
        "strings"
        "unsafe"
@@ -19,6 +20,11 @@ type NativeTun struct {
        name string
 }
 
+func (tun *NativeTun) IsUp() (bool, error) {
+       inter, err := net.InterfaceByName(tun.name)
+       return inter.Flags&net.FlagUp != 0, err
+}
+
 func (tun *NativeTun) Name() string {
        return tun.name
 }
index fd83918b9b4e98283aeea59bbe98cebe5093fc66..d6d78e733d39a2896464febcbb54cf45ba46b524 100644 (file)
@@ -11,13 +11,12 @@ import (
 )
 
 const (
-       ipcErrorIO           = int64(unix.EIO)
-       ipcErrorNoPeer       = int64(unix.EPROTO)
-       ipcErrorNoKeyValue   = int64(unix.EPROTO)
-       ipcErrorInvalidKey   = int64(unix.EPROTO)
-       ipcErrorInvalidValue = int64(unix.EPROTO)
-       socketDirectory      = "/var/run/wireguard"
-       socketName           = "%s.sock"
+       ipcErrorIO         = -int64(unix.EIO)
+       ipcErrorNotDefined = -int64(unix.ENODEV)
+       ipcErrorProtocol   = -int64(unix.EPROTO)
+       ipcErrorInvalid    = -int64(unix.EINVAL)
+       socketDirectory    = "/var/run/wireguard"
+       socketName         = "%s.sock"
 )
 
 /* TODO: