]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Number of fixes in response to code review
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 7 Aug 2017 13:25:04 +0000 (15:25 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 7 Aug 2017 13:25:04 +0000 (15:25 +0200)
This version cannot complete a handshake.
The program will panic upon receiving any message on the UDP socket.

12 files changed:
src/config.go
src/constants.go
src/daemon_linux.go
src/device.go
src/macs.go
src/peer.go
src/receive.go
src/send.go
src/timers.go
src/tun.go
src/tun_linux.go
src/uapi_linux.go

index e2d7f200c1efd3ec44b695acf3f204b9de6406b0..d952a3a9b57c120ba826fc34d653f9b57abc59f8 100644 (file)
@@ -84,13 +84,47 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        return nil
 }
 
+func updateUDPConn(device *Device) error {
+       var err error
+       netc := &device.net
+       netc.mutex.Lock()
+
+       // close existing connection
+
+       if netc.conn != nil {
+               netc.conn.Close()
+               netc.conn = nil
+       }
+
+       // open new existing connection
+
+       conn, err := net.ListenUDP("udp", netc.addr)
+       if err == nil {
+               netc.conn = conn
+               signalSend(device.signal.newUDPConn)
+       }
+
+       netc.mutex.Unlock()
+       return err
+}
+
+func closeUDPConn(device *Device) {
+       device.net.mutex.Lock()
+       device.net.conn = nil
+       device.net.mutex.Unlock()
+       println("send signal")
+       signalSend(device.signal.newUDPConn)
+}
+
 func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        scanner := bufio.NewScanner(socket)
+       logInfo := device.log.Info
        logError := device.log.Error
        logDebug := device.log.Debug
 
        var peer *Peer
 
+       dummy := false
        deviceConfig := true
 
        for scanner.Scan() {
@@ -135,17 +169,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                netc := &device.net
                                netc.mutex.Lock()
                                if netc.addr.Port != int(port) {
-                                       if netc.conn != nil {
-                                               netc.conn.Close()
-                                       }
                                        netc.addr.Port = int(port)
-                                       netc.conn, err = net.ListenUDP("udp", netc.addr)
                                }
                                netc.mutex.Unlock()
-                               if err != nil {
-                                       logError.Println("Failed to create UDP listener:", err)
-                                       return &IPCError{Code: ipcErrorIO}
-                               }
+                               updateUDPConn(device)
+
                                // TODO: Clear source address of all peers
 
                        case "fwmark":
@@ -189,17 +217,30 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                device.mutex.RLock()
                                if device.publicKey.Equals(pubKey) {
+
+                                       // create dummy instance
+
+                                       peer = &Peer{}
+                                       dummy = true
                                        device.mutex.RUnlock()
-                                       logError.Println("Public key of peer matches private key of device")
-                                       return &IPCError{Code: ipcErrorInvalid}
-                               }
+                                       logInfo.Println("Ignoring peer with public key of device")
+
+                               } else {
+
+                                       // find peer referenced
 
-                               // find peer referenced
+                                       peer, _ = device.peers[pubKey]
+                                       device.mutex.RUnlock()
+                                       if peer == nil {
+                                               peer, err = device.NewPeer(pubKey)
+                                               if err != nil {
+                                                       logError.Println("Failed to create new peer:", err)
+                                                       return &IPCError{Code: ipcErrorInvalid}
+                                               }
+                                       }
+                                       signalSend(peer.signal.handshakeReset)
+                                       dummy = false
 
-                               peer, _ = device.peers[pubKey]
-                               device.mutex.RUnlock()
-                               if peer == nil {
-                                       peer = device.NewPeer(pubKey)
                                }
 
                        case "remove":
@@ -207,16 +248,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logError.Println("Failed to set remove, invalid value:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-                               device.RemovePeer(peer.handshake.remoteStatic)
-                               logDebug.Println("Removing", peer.String())
-                               peer = nil
+                               if !dummy {
+                                       logDebug.Println("Removing", peer.String())
+                                       device.RemovePeer(peer.handshake.remoteStatic)
+                               }
+                               peer = &Peer{}
+                               dummy = true
 
                        case "preshared_key":
-                               err := func() error {
-                                       peer.mutex.Lock()
-                                       defer peer.mutex.Unlock()
-                                       return peer.handshake.presharedKey.FromHex(value)
-                               }()
+                               peer.mutex.Lock()
+                               err := peer.handshake.presharedKey.FromHex(value)
+                               peer.mutex.Unlock()
                                if err != nil {
                                        logError.Println("Failed to set preshared_key:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
@@ -232,6 +274,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                peer.mutex.Lock()
                                peer.endpoint = addr
                                peer.mutex.Unlock()
+                               signalSend(peer.signal.handshakeReset)
 
                        case "persistent_keepalive_interval":
 
@@ -251,12 +294,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                // send immediate keep-alive
 
                                if old == 0 && secs != 0 {
-                                       up, err := device.tun.IsUp()
                                        if err != nil {
                                                logError.Println("Failed to get tun device status:", err)
                                                return &IPCError{Code: ipcErrorIO}
                                        }
-                                       if up {
+                                       if atomic.LoadInt32(&device.isUp) == AtomicTrue && !dummy {
                                                peer.SendKeepAlive()
                                        }
                                }
@@ -266,7 +308,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-                               device.routingTable.RemovePeer(peer)
+                               if !dummy {
+                                       device.routingTable.RemovePeer(peer)
+                               }
 
                        case "allowed_ip":
                                _, network, err := net.ParseCIDR(value)
@@ -275,7 +319,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
                                ones, _ := network.Mask.Size()
-                               device.routingTable.Insert(network.IP, uint(ones), peer)
+                               if !dummy {
+                                       device.routingTable.Insert(network.IP, uint(ones), peer)
+                               }
 
                        default:
                                logError.Println("Invalid UAPI key (peer configuration):", key)
index f09ded62678db60fdc22079ce9e879a207d27a04..37603e8eb93b4408e0c45f19a470d168260c9601 100644 (file)
@@ -7,16 +7,15 @@ import (
 /* Specification constants */
 
 const (
-       RekeyAfterMessages      = (1 << 64) - (1 << 16) - 1
-       RejectAfterMessages     = (1 << 64) - (1 << 4) - 1
-       RekeyAfterTime          = time.Second * 120
-       RekeyAttemptTime        = time.Second * 90
-       RekeyTimeout            = time.Second * 5
-       RejectAfterTime         = time.Second * 180
-       KeepaliveTimeout        = time.Second * 10
-       CookieRefreshTime       = time.Second * 120
-       MaxHandshakeAttemptTime = time.Second * 90
-       PaddingMultiple         = 16
+       RekeyAfterMessages  = (1 << 64) - (1 << 16) - 1
+       RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+       RekeyAfterTime      = time.Second * 120
+       RekeyAttemptTime    = time.Second * 90
+       RekeyTimeout        = time.Second * 5
+       RejectAfterTime     = time.Second * 180
+       KeepaliveTimeout    = time.Second * 10
+       CookieRefreshTime   = time.Second * 120
+       PaddingMultiple     = 16
 )
 
 const (
@@ -33,4 +32,5 @@ const (
        QueueHandshakeBusySize = QueueHandshakeSize / 8
        MinMessageSize         = MessageTransportSize // size of keep-alive
        MaxMessageSize         = ((1 << 16) - 1) + MessageTransportHeaderSize
+       MaxPeers               = 1 << 16
 )
index 809c176de313d6f433b9dd5399b6859894ece9dc..730f89efa5d36689a2e53fad6e49956369a4cb0b 100644 (file)
@@ -7,6 +7,8 @@ import (
 /* Daemonizes the process on linux
  *
  * This is done by spawning and releasing a copy with the --foreground flag
+ *
+ * TODO: Use env variable to spawn in background
  */
 
 func Daemonize() error {
index de96f0b19073bb4aadbca712ae476e6cf3220a52..4aa90e345a8ff207a65847fe8577287e786f6449 100644 (file)
@@ -1,13 +1,10 @@
 package main
 
 import (
-       "errors"
-       "fmt"
        "net"
        "runtime"
        "sync"
        "sync/atomic"
-       "time"
 )
 
 type Device struct {
@@ -34,31 +31,45 @@ type Device struct {
        queue        struct {
                encryption chan *QueueOutboundElement
                decryption chan *QueueInboundElement
-               inbound    chan *QueueInboundElement
                handshake  chan QueueHandshakeElement
        }
        signal struct {
-               stop chan struct{}
+               stop       chan struct{} // halts all go routines
+               newUDPConn chan struct{} // a net.conn was set
        }
-       underLoad   int32 // used as an atomic bool
+       isUp        int32 // atomic bool: interface is up
+       underLoad   int32 // atomic bool: device is under load
        ratelimiter Ratelimiter
        peers       map[NoisePublicKey]*Peer
        mac         MACStateDevice
 }
 
+/* Warning:
+ * The caller must hold the device mutex (write lock)
+ */
+func removePeerUnsafe(device *Device, key NoisePublicKey) {
+       peer, ok := device.peers[key]
+       if !ok {
+               return
+       }
+       peer.mutex.Lock()
+       device.routingTable.RemovePeer(peer)
+       delete(device.peers, key)
+       peer.Close()
+}
+
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
-       // check if public key is matching any peer
+       // remove peers with matching public keys
 
        publicKey := sk.publicKey()
-       for _, peer := range device.peers {
+       for key, peer := range device.peers {
                h := &peer.handshake
                h.mutex.RLock()
                if h.remoteStatic.Equals(publicKey) {
-                       h.mutex.RUnlock()
-                       return errors.New("Private key matches public key of peer")
+                       removePeerUnsafe(device, key)
                }
                h.mutex.RUnlock()
        }
@@ -71,17 +82,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 
        // do DH precomputations
 
-       isZero := device.privateKey.IsZero()
+       rmKey := device.privateKey.IsZero()
 
-       for _, peer := range device.peers {
+       for key, peer := range device.peers {
                h := &peer.handshake
                h.mutex.Lock()
-               if isZero {
+               if rmKey {
                        h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
                } else {
                        h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+                       if isZero(h.precomputedStaticStatic[:]) {
+                               removePeerUnsafe(device, key)
+                       }
                }
-               fmt.Println(h.precomputedStaticStatic)
                h.mutex.Unlock()
        }
 
@@ -130,11 +143,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
        device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
        device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
-       device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
 
        // prepare signals
 
        device.signal.stop = make(chan struct{})
+       device.signal.newUDPConn = make(chan struct{}, 1)
 
        // start workers
 
@@ -145,33 +158,42 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        }
 
        go device.RoutineBusyMonitor()
-       go device.RoutineMTUUpdater()
-       go device.RoutineWriteToTUN()
        go device.RoutineReadFromTUN()
+       go device.RoutineTUNEventReader()
        go device.RoutineReceiveIncomming()
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 
        return device
 }
 
-func (device *Device) RoutineMTUUpdater() {
+func (device *Device) RoutineTUNEventReader() {
+       events := device.tun.Events()
        logError := device.log.Error
-       for ; ; time.Sleep(5 * time.Second) {
 
-               // load updated MTU
-
-               mtu, err := device.tun.MTU()
-               if err != nil {
-                       logError.Println("Failed to load updated MTU of device:", err)
-                       continue
+       for event := range events {
+               if event&TUNEventMTUUpdate != 0 {
+                       mtu, err := device.tun.MTU()
+                       if err != nil {
+                               logError.Println("Failed to load updated MTU of device:", err)
+                       } else {
+                               if mtu+MessageTransportSize > MaxMessageSize {
+                                       mtu = MaxMessageSize - MessageTransportSize
+                               }
+                               atomic.StoreInt32(&device.mtu, int32(mtu))
+                       }
                }
 
-               // upper bound of mtu
+               if event&TUNEventUp != 0 {
+                       println("handle 1")
+                       atomic.StoreInt32(&device.isUp, AtomicTrue)
+                       updateUDPConn(device)
+                       println("handle 2", device.net.conn)
+               }
 
-               if mtu+MessageTransportSize > MaxMessageSize {
-                       mtu = MaxMessageSize - MessageTransportSize
+               if event&TUNEventDown != 0 {
+                       atomic.StoreInt32(&device.isUp, AtomicFalse)
+                       closeUDPConn(device)
                }
-               atomic.StoreInt32(&device.mtu, int32(mtu))
        }
 }
 
@@ -184,15 +206,7 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
 func (device *Device) RemovePeer(key NoisePublicKey) {
        device.mutex.Lock()
        defer device.mutex.Unlock()
-
-       peer, ok := device.peers[key]
-       if !ok {
-               return
-       }
-       peer.mutex.Lock()
-       device.routingTable.RemovePeer(peer)
-       delete(device.peers, key)
-       peer.Close()
+       removePeerUnsafe(device, key)
 }
 
 func (device *Device) RemoveAllPeers() {
index beb5f7689c32019b7d11e26487327881d2fb8693..d55e18ffb272624876eba1b687274bb4baa5bfa1 100644 (file)
@@ -18,12 +18,13 @@ type MACStateDevice struct {
 }
 
 type MACStatePeer struct {
-       mutex     sync.RWMutex
-       cookieSet time.Time
-       cookie    [blake2s.Size128]byte
-       lastMAC1  [blake2s.Size128]byte // TODO: Check if set
-       keyMAC1   [blake2s.Size]byte
-       keyMAC2   [blake2s.Size]byte
+       mutex       sync.RWMutex
+       cookieSet   time.Time
+       cookie      [blake2s.Size128]byte
+       lastMAC1Set bool
+       lastMAC1    [blake2s.Size128]byte
+       keyMAC1     [blake2s.Size]byte
+       keyMAC2     [blake2s.Size]byte
 }
 
 /* Methods for verifing MAC fields
@@ -184,6 +185,10 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
        state.mutex.Lock()
        defer state.mutex.Unlock()
 
+       if !state.lastMAC1Set {
+               return false
+       }
+
        _, err := XChaCha20Poly1305Decrypt(
                cookie[:0],
                &msg.Nonce,
@@ -246,7 +251,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
                mac.Sum(mac1[:0])
        }()
        copy(state.lastMAC1[:], mac1)
-       // TODO: Set lastMac flag
+       state.lastMAC1Set = true
 
        // set mac2
 
index 9136959d6b4c192225a8027cf6795a932a0bc4f1..02aac3b9f534a57bf9e9348a6d8825673e2438a2 100644 (file)
@@ -9,16 +9,14 @@ import (
        "time"
 )
 
-const ()
-
 type Peer struct {
        id                          uint
        mutex                       sync.RWMutex
-       endpoint                    *net.UDPAddr
        persistentKeepaliveInterval uint64
        keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
+       endpoint                    *net.UDPAddr
        stats                       struct {
                txBytes           uint64 // bytes send to peer (endpoint)
                rxBytes           uint64 // bytes received from peer
@@ -34,6 +32,7 @@ type Peer struct {
                newKeyPair         chan struct{} // (size 1) : a new key pair was generated
                handshakeBegin     chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
                handshakeCompleted chan struct{} // (size 1) : handshake completed
+               handshakeReset     chan struct{} // (size 1) : reset handshake negotiation state
                flushNonceQueue    chan struct{} // (size 1) : empty queued packets
                messageSend        chan struct{} // (size 1) : a message was send to the peer
                messageReceived    chan struct{} // (size 1) : an authenticated message was received
@@ -44,6 +43,7 @@ type Peer struct {
                keepalivePassive    *time.Timer // set upon recieving messages
                newHandshake        *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
                zeroAllKeys         *time.Timer // zero all key material (after RejectAfterTime*3)
+               handshakeDeadline   *time.Timer // Current handshake must be completed
 
                pendingKeepalivePassive bool
                pendingNewHandshake     bool
@@ -59,7 +59,7 @@ type Peer struct {
        mac MACStatePeer
 }
 
-func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
+func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        // create peer
 
        peer := new(Peer)
@@ -80,11 +80,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        peer.id = device.idCounter
        device.idCounter += 1
 
+       // check if over limit
+
+       if len(device.peers) >= MaxPeers {
+               return nil, errors.New("Too many peers")
+       }
+
        // map public key
 
        _, ok := device.peers[pk]
        if ok {
-               panic(errors.New("bug: adding existing peer"))
+               return nil, errors.New("Adding existing peer")
        }
        device.peers[pk] = peer
        device.mutex.Unlock()
@@ -108,6 +114,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        peer.signal.stop = make(chan struct{})
        peer.signal.newKeyPair = make(chan struct{}, 1)
        peer.signal.handshakeBegin = make(chan struct{}, 1)
+       peer.signal.handshakeReset = make(chan struct{}, 1)
        peer.signal.handshakeCompleted = make(chan struct{}, 1)
        peer.signal.flushNonceQueue = make(chan struct{}, 1)
 
@@ -117,7 +124,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        go peer.RoutineSequentialSender()
        go peer.RoutineSequentialReceiver()
 
-       return peer
+       return peer, nil
 }
 
 func (peer *Peer) String() string {
index fb5c51fa5a5e02c1e53697642e922c8c1c314ace..5f469257e6946f3af07868d2bd2ab858aa469056 100644 (file)
@@ -111,113 +111,84 @@ func (device *Device) RoutineBusyMonitor() {
 
 func (device *Device) RoutineReceiveIncomming() {
 
-       logInfo := device.log.Info
        logDebug := device.log.Debug
        logDebug.Println("Routine, receive incomming, started")
 
-       var buffer *[MaxMessageSize]byte
-
        for {
 
-               // check if stopped
+               // wait for new conn
+
+               var conn *net.UDPConn
 
                select {
+               case <-device.signal.newUDPConn:
+                       device.net.mutex.RLock()
+                       conn = device.net.conn
+                       device.net.mutex.RUnlock()
+
                case <-device.signal.stop:
                        return
-               default:
                }
 
-               // read next datagram
-
-               if buffer == nil {
-                       buffer = device.GetMessageBuffer()
-               }
-
-               // TODO: Take writelock to sleep
-               device.net.mutex.RLock()
-               conn := device.net.conn
-               device.net.mutex.RUnlock()
                if conn == nil {
-                       time.Sleep(time.Second)
                        continue
                }
 
-               // TODO: Wait for new conn or message
-               conn.SetReadDeadline(time.Now().Add(time.Second))
+               // receive datagrams until closed
 
-               size, raddr, err := conn.ReadFromUDP(buffer[:])
-               if err != nil || size < MinMessageSize {
-                       continue
-               }
+               buffer := device.GetMessageBuffer()
 
-               // handle packet
+               for {
 
-               packet := buffer[:size]
-               msgType := binary.LittleEndian.Uint32(packet[:4])
+                       // read next datagram
 
-               func() {
-                       switch msgType {
-
-                       case MessageInitiationType, MessageResponseType:
-
-                               // TODO: Check size early
+                       size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken
 
-                               // add to handshake queue
+                       if err != nil {
+                               break
+                       }
 
-                               device.addToHandshakeQueue(
-                                       device.queue.handshake,
-                                       QueueHandshakeElement{
-                                               msgType: msgType,
-                                               buffer:  buffer,
-                                               packet:  packet,
-                                               source:  raddr,
-                                       },
-                               )
-                               buffer = nil
+                       if size < MinMessageSize {
+                               continue
+                       }
 
-                       case MessageCookieReplyType:
+                       // check size of packet
 
-                               // TODO: Queue all the things
+                       packet := buffer[:size]
+                       msgType := binary.LittleEndian.Uint32(packet[:4])
 
-                               // verify and update peer cookie state
+                       var okay bool
 
-                               if len(packet) != MessageCookieReplySize {
-                                       return
-                               }
+                       switch msgType {
 
-                               var reply MessageCookieReply
-                               reader := bytes.NewReader(packet)
-                               err := binary.Read(reader, binary.LittleEndian, &reply)
-                               if err != nil {
-                                       logDebug.Println("Failed to decode cookie reply")
-                                       return
-                               }
-                               device.ConsumeMessageCookieReply(&reply)
+                       // check if transport
 
                        case MessageTransportType:
 
-                               // lookup key pair
+                               // check size
 
-                               if len(packet) < MessageTransportSize {
-                                       return
+                               if len(packet) < MessageTransportType {
+                                       continue
                                }
 
+                               // lookup key pair
+
                                receiver := binary.LittleEndian.Uint32(
                                        packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
                                )
                                value := device.indices.Lookup(receiver)
                                keyPair := value.keyPair
                                if keyPair == nil {
-                                       return
+                                       continue
                                }
 
                                // check key-pair expiry
 
                                if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
-                                       return
+                                       continue
                                }
 
-                               // add to peer queue
+                               // create work element
 
                                peer := value.peer
                                elem := &QueueInboundElement{
@@ -233,11 +204,33 @@ func (device *Device) RoutineReceiveIncomming() {
                                device.addToInboundQueue(device.queue.decryption, elem)
                                device.addToInboundQueue(peer.queue.inbound, elem)
                                buffer = nil
+                               continue
 
-                       default:
-                               logInfo.Println("Got unknown message from:", raddr)
+                       // otherwise it is a handshake related packet
+
+                       case MessageInitiationType:
+                               okay = len(packet) == MessageInitiationSize
+
+                       case MessageResponseType:
+                               okay = len(packet) == MessageResponseSize
+
+                       case MessageCookieReplyType:
+                               okay = len(packet) == MessageCookieReplySize
                        }
-               }()
+
+                       if okay {
+                               device.addToHandshakeQueue(
+                                       device.queue.handshake,
+                                       QueueHandshakeElement{
+                                               msgType: msgType,
+                                               buffer:  buffer,
+                                               packet:  packet,
+                                               source:  raddr,
+                                       },
+                               )
+                               buffer = device.GetMessageBuffer()
+                       }
+               }
        }
 }
 
@@ -306,154 +299,165 @@ func (device *Device) RoutineHandshake() {
                        return
                }
 
-               func() {
+               // handle cookie fields and ratelimiting
 
-                       // verify mac1
+               switch elem.msgType {
 
-                       if !device.mac.CheckMAC1(elem.packet) {
-                               logDebug.Println("Received packet with invalid mac1")
+               case MessageCookieReplyType:
+
+                       // verify and update peer cookie state
+
+                       var reply MessageCookieReply
+                       reader := bytes.NewReader(elem.packet)
+                       err := binary.Read(reader, binary.LittleEndian, &reply)
+                       if err != nil {
+                               logDebug.Println("Failed to decode cookie reply")
                                return
                        }
+                       device.ConsumeMessageCookieReply(&reply)
+                       continue
 
-                       // verify mac2
+               case MessageInitiationType, MessageResponseType:
 
-                       busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
+                       // check mac fields and ratelimit
 
-                       if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
-                               sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
-                               reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
-                               if err != nil {
-                                       logError.Println("Failed to create cookie reply:", err)
-                                       return
-                               }
-                               // TODO: Use temp
-                               writer := bytes.NewBuffer(elem.packet[:0])
-                               binary.Write(writer, binary.LittleEndian, reply)
-                               elem.packet = writer.Bytes()
-                               _, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
-                               if err != nil {
-                                       logDebug.Println("Failed to send cookie reply:", err)
-                               }
+                       if !device.mac.CheckMAC1(elem.packet) {
+                               logDebug.Println("Received packet with invalid mac1")
                                return
                        }
 
-                       // ratelimit
-
-                       // TODO: Only ratelimit when busy
+                       busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
 
-                       if !device.ratelimiter.Allow(elem.source.IP) {
-                               return
+                       if busy {
+                               if !device.mac.CheckMAC2(elem.packet, elem.source) {
+                                       sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
+                                       reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
+                                       if err != nil {
+                                               logError.Println("Failed to create cookie reply:", err)
+                                               return
+                                       }
+                                       writer := bytes.NewBuffer(temp[:0])
+                                       binary.Write(writer, binary.LittleEndian, reply)
+                                       _, err = device.net.conn.WriteToUDP(
+                                               writer.Bytes(),
+                                               elem.source,
+                                       )
+                                       if err != nil {
+                                               logDebug.Println("Failed to send cookie reply:", err)
+                                       }
+                                       continue
+                               }
+                               if !device.ratelimiter.Allow(elem.source.IP) {
+                                       continue
+                               }
                        }
 
-                       // handle messages
+               default:
+                       logError.Println("Invalid packet ended up in the handshake queue")
+                       continue
+               }
 
-                       switch elem.msgType {
-                       case MessageInitiationType:
+               // handle handshake initation/response content
 
-                               // unmarshal
+               switch elem.msgType {
+               case MessageInitiationType:
 
-                               if len(elem.packet) != MessageInitiationSize {
-                                       return
-                               }
+                       // unmarshal
 
-                               var msg MessageInitiation
-                               reader := bytes.NewReader(elem.packet)
-                               err := binary.Read(reader, binary.LittleEndian, &msg)
-                               if err != nil {
-                                       logError.Println("Failed to decode initiation message")
-                                       return
-                               }
+                       var msg MessageInitiation
+                       reader := bytes.NewReader(elem.packet)
+                       err := binary.Read(reader, binary.LittleEndian, &msg)
+                       if err != nil {
+                               logError.Println("Failed to decode initiation message")
+                               continue
+                       }
 
-                               // consume initiation
+                       // consume initiation
 
-                               peer := device.ConsumeMessageInitiation(&msg)
-                               if peer == nil {
-                                       logInfo.Println(
-                                               "Recieved invalid initiation message from",
-                                               elem.source.IP.String(),
-                                               elem.source.Port,
-                                       )
-                                       return
-                               }
+                       peer := device.ConsumeMessageInitiation(&msg)
+                       if peer == nil {
+                               logInfo.Println(
+                                       "Recieved invalid initiation message from",
+                                       elem.source.IP.String(),
+                                       elem.source.Port,
+                               )
+                               continue
+                       }
 
-                               // update timers
+                       // update timers
 
-                               peer.TimerAnyAuthenticatedPacketTraversal()
-                               peer.TimerAnyAuthenticatedPacketReceived()
+                       peer.TimerAnyAuthenticatedPacketTraversal()
+                       peer.TimerAnyAuthenticatedPacketReceived()
 
-                               // update endpoint
-                               // TODO: Add a race condition \s
+                       // update endpoint
+                       // TODO: Discover destination address also, only update on change
 
-                               peer.mutex.Lock()
-                               peer.endpoint = elem.source
-                               peer.mutex.Unlock()
+                       peer.mutex.Lock()
+                       peer.endpoint = elem.source
+                       peer.mutex.Unlock()
 
-                               // create response
+                       // create response
 
-                               response, err := device.CreateMessageResponse(peer)
-                               if err != nil {
-                                       logError.Println("Failed to create response message:", err)
-                                       return
-                               }
+                       response, err := device.CreateMessageResponse(peer)
+                       if err != nil {
+                               logError.Println("Failed to create response message:", err)
+                               continue
+                       }
 
-                               peer.TimerEphemeralKeyCreated()
-                               peer.NewKeyPair()
+                       peer.TimerEphemeralKeyCreated()
+                       peer.NewKeyPair()
 
-                               logDebug.Println("Creating response message for", peer.String())
+                       logDebug.Println("Creating response message for", peer.String())
 
-                               writer := bytes.NewBuffer(temp[:0])
-                               binary.Write(writer, binary.LittleEndian, response)
-                               packet := writer.Bytes()
-                               peer.mac.AddMacs(packet)
+                       writer := bytes.NewBuffer(temp[:0])
+                       binary.Write(writer, binary.LittleEndian, response)
+                       packet := writer.Bytes()
+                       peer.mac.AddMacs(packet)
 
-                               // send response
+                       // send response
 
-                               peer.SendBuffer(packet)
+                       _, err = peer.SendBuffer(packet)
+                       if err == nil {
                                peer.TimerAnyAuthenticatedPacketTraversal()
+                       }
 
-                       case MessageResponseType:
+               case MessageResponseType:
 
-                               // unmarshal
+                       // unmarshal
 
-                               if len(elem.packet) != MessageResponseSize {
-                                       return
-                               }
-
-                               var msg MessageResponse
-                               reader := bytes.NewReader(elem.packet)
-                               err := binary.Read(reader, binary.LittleEndian, &msg)
-                               if err != nil {
-                                       logError.Println("Failed to decode response message")
-                                       return
-                               }
+                       var msg MessageResponse
+                       reader := bytes.NewReader(elem.packet)
+                       err := binary.Read(reader, binary.LittleEndian, &msg)
+                       if err != nil {
+                               logError.Println("Failed to decode response message")
+                               continue
+                       }
 
-                               // consume response
+                       // consume response
 
-                               peer := device.ConsumeMessageResponse(&msg)
-                               if peer == nil {
-                                       logInfo.Println(
-                                               "Recieved invalid response message from",
-                                               elem.source.IP.String(),
-                                               elem.source.Port,
-                                       )
-                                       return
-                               }
+                       peer := device.ConsumeMessageResponse(&msg)
+                       if peer == nil {
+                               logInfo.Println(
+                                       "Recieved invalid response message from",
+                                       elem.source.IP.String(),
+                                       elem.source.Port,
+                               )
+                               continue
+                       }
 
-                               // update timers
+                       peer.TimerEphemeralKeyCreated()
 
-                               peer.TimerAnyAuthenticatedPacketTraversal()
-                               peer.TimerAnyAuthenticatedPacketReceived()
-                               peer.TimerHandshakeComplete()
+                       // update timers
 
-                               // derive key-pair
+                       peer.TimerAnyAuthenticatedPacketTraversal()
+                       peer.TimerAnyAuthenticatedPacketReceived()
+                       peer.TimerHandshakeComplete()
 
-                               peer.NewKeyPair()
-                               peer.SendKeepAlive()
+                       // derive key-pair
 
-                       default:
-                               logError.Println("Invalid message type in handshake queue")
-                       }
-               }()
+                       peer.NewKeyPair()
+                       peer.SendKeepAlive()
+               }
        }
 }
 
@@ -463,6 +467,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
        device := peer.device
 
        logInfo := device.log.Info
+       logError := device.log.Error
        logDebug := device.log.Debug
        logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
 
@@ -478,116 +483,104 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                // process packet
 
-               func() {
-                       if elem.IsDropped() {
-                               return
-                       }
-
-                       // check for replay
-
-                       if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
-                               return
-                       }
+               if elem.IsDropped() {
+                       continue
+               }
 
-                       peer.TimerAnyAuthenticatedPacketTraversal()
-                       peer.TimerAnyAuthenticatedPacketReceived()
-                       peer.KeepKeyFreshReceiving()
+               // check for replay
 
-                       // check if using new key-pair
+               if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+                       continue
+               }
 
-                       kp := &peer.keyPairs
-                       kp.mutex.Lock()
-                       if kp.next == elem.keyPair {
-                               peer.TimerHandshakeComplete()
-                               kp.previous = kp.current
-                               kp.current = kp.next
-                               kp.next = nil
-                       }
-                       kp.mutex.Unlock()
+               peer.TimerAnyAuthenticatedPacketTraversal()
+               peer.TimerAnyAuthenticatedPacketReceived()
+               peer.KeepKeyFreshReceiving()
 
-                       // check for keep-alive
+               // check if using new key-pair
 
-                       if len(elem.packet) == 0 {
-                               logDebug.Println("Received keep-alive from", peer.String())
-                               return
-                       }
-                       peer.TimerDataReceived()
+               kp := &peer.keyPairs
+               kp.mutex.Lock()
+               if kp.next == elem.keyPair {
+                       peer.TimerHandshakeComplete()
+                       kp.previous = kp.current
+                       kp.current = kp.next
+                       kp.next = nil
+               }
+               kp.mutex.Unlock()
 
-                       // verify source and strip padding
+               // check for keep-alive
 
-                       switch elem.packet[0] >> 4 {
-                       case ipv4.Version:
+               if len(elem.packet) == 0 {
+                       logDebug.Println("Received keep-alive from", peer.String())
+                       continue
+               }
+               peer.TimerDataReceived()
 
-                               // strip padding
+               // verify source and strip padding
 
-                               if len(elem.packet) < ipv4.HeaderLen {
-                                       return
-                               }
+               switch elem.packet[0] >> 4 {
+               case ipv4.Version:
 
-                               field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
-                               length := binary.BigEndian.Uint16(field)
-                               // TODO: check length of packet & NOT TOO SMALL either
-                               elem.packet = elem.packet[:length]
+                       // strip padding
 
-                               // verify IPv4 source
+                       if len(elem.packet) < ipv4.HeaderLen {
+                               continue
+                       }
 
-                               src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
-                               if device.routingTable.LookupIPv4(src) != peer {
-                                       logInfo.Println("Packet with unallowed source IP from", peer.String())
-                                       return
-                               }
+                       field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+                       length := binary.BigEndian.Uint16(field)
+                       if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+                               continue
+                       }
 
-                       case ipv6.Version:
+                       elem.packet = elem.packet[:length]
 
-                               // strip padding
+                       // verify IPv4 source
 
-                               if len(elem.packet) < ipv6.HeaderLen {
-                                       return
-                               }
+                       src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+                       if device.routingTable.LookupIPv4(src) != peer {
+                               logInfo.Println("Packet with unallowed source IP from", peer.String())
+                               continue
+                       }
 
-                               field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
-                               length := binary.BigEndian.Uint16(field)
-                               length += ipv6.HeaderLen
-                               // TODO: check length of packet
-                               elem.packet = elem.packet[:length]
+               case ipv6.Version:
 
-                               // verify IPv6 source
+                       // strip padding
 
-                               src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
-                               if device.routingTable.LookupIPv6(src) != peer {
-                                       logInfo.Println("Packet with unallowed source IP from", peer.String())
-                                       return
-                               }
+                       if len(elem.packet) < ipv6.HeaderLen {
+                               continue
+                       }
 
-                       default:
-                               logInfo.Println("Packet with invalid IP version from", peer.String())
-                               return
+                       field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+                       length := binary.BigEndian.Uint16(field)
+                       length += ipv6.HeaderLen
+                       if int(length) > len(elem.packet) {
+                               continue
                        }
 
-                       atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
-                       device.addToInboundQueue(device.queue.inbound, elem)
+                       elem.packet = elem.packet[:length]
 
-                       // TODO: move TUN write into per peer routine
-               }()
-       }
-}
+                       // verify IPv6 source
 
-func (device *Device) RoutineWriteToTUN() {
+                       src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+                       if device.routingTable.LookupIPv6(src) != peer {
+                               logInfo.Println("Packet with unallowed source IP from", peer.String())
+                               continue
+                       }
 
-       logError := device.log.Error
-       logDebug := device.log.Debug
-       logDebug.Println("Routine, sequential tun writer, started")
+               default:
+                       logInfo.Println("Packet with invalid IP version from", peer.String())
+                       continue
+               }
 
-       for {
-               select {
-               case <-device.signal.stop:
-                       return
-               case elem := <-device.queue.inbound:
-                       _, err := device.tun.Write(elem.packet)
-                       device.PutMessageBuffer(elem.buffer)
-                       if err != nil {
-                               logError.Println("Failed to write packet to TUN device:", err)
-                       }
+               // write to tun
+
+               atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+               _, err := device.tun.Write(elem.packet)
+               device.PutMessageBuffer(elem.buffer)
+               if err != nil {
+                       logError.Println("Failed to write packet to TUN device:", err)
                }
        }
 }
index fc3573264521e0193147628710c89af6da2969a8..cf1f018fe9835a6b93f9b40ae0210b5bd7f9610e 100644 (file)
@@ -168,8 +168,6 @@ func (device *Device) RoutineReadFromTUN() {
                        continue
                }
 
-               println(size, err)
-
                elem.packet = elem.packet[:size]
 
                // lookup peer
@@ -210,6 +208,7 @@ func (device *Device) RoutineReadFromTUN() {
 
                // insert into nonce/pre-handshake queue
 
+               signalSend(peer.signal.handshakeReset)
                addToOutboundQueue(peer.queue.nonce, elem)
                elem = nil
 
index 1be85f07cede46033eabacbfaf582ae6840ec4a2..ab2e7adf59fc53b8130288b6be0056a31ae6fdfb 100644 (file)
@@ -4,6 +4,7 @@ import (
        "bytes"
        "encoding/binary"
        "golang.org/x/crypto/blake2s"
+       "math/rand"
        "sync/atomic"
        "time"
 )
@@ -16,12 +17,11 @@ func (peer *Peer) KeepKeyFreshSending() {
        if kp == nil {
                return
        }
-       if !kp.isInitiator {
-               return
-       }
        nonce := atomic.LoadUint64(&kp.sendNonce)
-       send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
-       if send {
+       if nonce > RekeyAfterMessages {
+               signalSend(peer.signal.handshakeBegin)
+       }
+       if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
                signalSend(peer.signal.handshakeBegin)
        }
 }
@@ -30,6 +30,7 @@ func (peer *Peer) KeepKeyFreshSending() {
  *
  */
 func (peer *Peer) KeepKeyFreshReceiving() {
+       // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
        kp := peer.keyPairs.Current()
        if kp == nil {
                return
@@ -108,7 +109,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
  * - First transport message under the "next" key
  */
 func (peer *Peer) TimerHandshakeComplete() {
-       timerStop(peer.timer.zeroAllKeys)
        atomic.StoreInt64(
                &peer.stats.lastHandshakeNano,
                time.Now().UnixNano(),
@@ -129,10 +129,7 @@ func (peer *Peer) TimerHandshakeComplete() {
  * upon failure to complete a handshake
  */
 func (peer *Peer) TimerEphemeralKeyCreated() {
-       if !peer.timer.pendingZeroAllKeys {
-               peer.timer.pendingZeroAllKeys = true
-               peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
-       }
+       peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
 }
 
 func (peer *Peer) RoutineTimerHandler() {
@@ -154,19 +151,19 @@ func (peer *Peer) RoutineTimerHandler() {
 
                        interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
                        if interval > 0 {
-                               logDebug.Println("Sending persistent keep-alive to", peer.String())
+                               logDebug.Println("Sending keep-alive to", peer.String())
                                peer.SendKeepAlive()
                        }
 
                case <-peer.timer.keepalivePassive.C:
 
-                       logDebug.Println("Sending passive keep-alive to", peer.String())
+                       logDebug.Println("Sending keep-alive to", peer.String())
 
                        peer.SendKeepAlive()
 
                        if peer.timer.needAnotherKeepalive {
                                peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
-                               peer.timer.needAnotherKeepalive = true
+                               peer.timer.needAnotherKeepalive = false
                        }
 
                // unresponsive session
@@ -189,8 +186,6 @@ func (peer *Peer) RoutineTimerHandler() {
                        kp := &peer.keyPairs
                        kp.mutex.Lock()
 
-                       peer.timer.pendingZeroAllKeys = false
-
                        // unmap indecies
 
                        indices.mutex.Lock()
@@ -251,40 +246,41 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                        return
                }
 
-               // wait for handshake
+               // set deadline
+
+       BeginHandshakes:
 
-               deadline := time.Now().Add(MaxHandshakeAttemptTime)
+               signalClear(peer.signal.handshakeReset)
+               deadline := time.NewTimer(RekeyAttemptTime)
+
+       AttemptHandshakes:
 
-       Loop:
                for attempts := uint(1); ; attempts++ {
 
-                       // clear completed signal
+                       // check if deadline reached
 
                        select {
-                       case <-peer.signal.handshakeCompleted:
+                       case <-deadline.C:
+                               logInfo.Println("Handshake negotiation timed out for:", peer.String())
+                               signalSend(peer.signal.flushNonceQueue)
+                               timerStop(peer.timer.keepalivePersistent)
+                               break
                        case <-peer.signal.stop:
                                return
                        default:
                        }
 
-                       // check if sufficient time for retry
-
-                       if deadline.Before(time.Now().Add(RekeyTimeout)) {
-                               logInfo.Println("Handshake negotiation timed out for", peer.String())
-                               signalSend(peer.signal.flushNonceQueue)
-                               timerStop(peer.timer.keepalivePersistent)
-                               timerStop(peer.timer.keepalivePassive)
-                               break Loop
-                       }
+                       signalClear(peer.signal.handshakeCompleted)
 
                        // create initiation message
 
                        msg, err := peer.device.CreateMessageInitiation(peer)
                        if err != nil {
                                logError.Println("Failed to create handshake initiation message:", err)
-                               break Loop
+                               break AttemptHandshakes
                        }
-                       peer.TimerEphemeralKeyCreated()
+
+                       jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
 
                        // marshal and send
 
@@ -299,14 +295,14 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                                        "Failed to send handshake initiation message to",
                                        peer.String(), ":", err,
                                )
-                               continue
+                               break
                        }
 
                        peer.TimerAnyAuthenticatedPacketTraversal()
 
-                       // set timeout
+                       // set handshake timeout
 
-                       timeout := time.NewTimer(RekeyTimeout)
+                       timeout := time.NewTimer(RekeyTimeout + jitter)
                        logDebug.Println(
                                "Handshake initiation attempt",
                                attempts, "sent to", peer.String(),
@@ -321,15 +317,19 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 
                        case <-peer.signal.handshakeCompleted:
                                <-timeout.C
-                               break Loop
+                               break AttemptHandshakes
+
+                       case <-peer.signal.handshakeReset:
+                               <-timeout.C
+                               goto BeginHandshakes
 
                        case <-timeout.C:
+                               // TODO: Clear source address for peer
                                continue
-
                        }
                }
 
-               // allow new signal to be set
+               // clear signal set in the meantime
 
                signalClear(peer.signal.handshakeBegin)
        }
index d782bd57786bf1c541fb5196e4fe765cde01ac0f..1c4c281059faf7344340569762e1a604eeb4f608 100644 (file)
@@ -6,10 +6,19 @@ package main
 
 const DefaultMTU = 1420
 
+type TUNEvent int
+
+const (
+       TUNEventUp = 1 << iota
+       TUNEventDown
+       TUNEventMTUUpdate
+)
+
 type TUNDevice interface {
        Read([]byte) (int, error)  // read a packet from the device (without any additional headers)
        Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
-       IsUp() (bool, error)       // is the interface up?
        MTU() (int, error)         // returns the MTU of the device
        Name() string              // returns the current name
+       Events() chan TUNEvent     // returns a constant channel of events related to the device
+       Close() error              // stops the device and closes the event channel
 }
index d0e2f470592db6d177f1ef50cb0e035bb26f14b3..34f746a70b469960669988dd5e84f2690b345171 100644 (file)
@@ -16,11 +16,12 @@ import (
 const CloneDevicePath = "/dev/net/tun"
 
 type NativeTun struct {
-       fd   *os.File
-       name string
+       fd     *os.File
+       name   string
+       events chan TUNEvent
 }
 
-func (tun *NativeTun) IsUp() (bool, error) {
+func (tun *NativeTun) isUp() (bool, error) {
        inter, err := net.InterfaceByName(tun.name)
        return inter.Flags&net.FlagUp != 0, err
 }
@@ -111,6 +112,14 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
        return tun.fd.Read(d)
 }
 
+func (tun *NativeTun) Events() chan TUNEvent {
+       return tun.events
+}
+
+func (tun *NativeTun) Close() error {
+       return nil
+}
+
 func CreateTUN(name string) (TUNDevice, error) {
 
        // open clone device
@@ -146,10 +155,14 @@ func CreateTUN(name string) (TUNDevice, error) {
        newName := string(ifr[:])
        newName = newName[:strings.Index(newName, "\000")]
        device := &NativeTun{
-               fd:   fd,
-               name: newName,
+               fd:     fd,
+               name:   newName,
+               events: make(chan TUNEvent, 5),
        }
 
+       // TODO: Wait for device to be upped
+       device.events <- TUNEventUp
+
        // set default MTU
 
        err = device.setMTU(DefaultMTU)
index d6d78e733d39a2896464febcbb54cf45ba46b524..fd56b5a72364d9ce76a933de8244255ca0fc7287 100644 (file)
@@ -7,7 +7,6 @@ import (
        "net"
        "os"
        "path"
-       "time"
 )
 
 const (
@@ -26,9 +25,10 @@ const (
  */
 
 type UAPIListener struct {
-       listener net.Listener // unix socket listener
-       connNew  chan net.Conn
-       connErr  chan error
+       listener  net.Listener // unix socket listener
+       connNew   chan net.Conn
+       connErr   chan error
+       inotifyFd int
 }
 
 func (l *UAPIListener) Accept() (net.Conn, error) {
@@ -106,9 +106,28 @@ func NewUAPIListener(name string) (net.Listener, error) {
 
        // watch for deletion of socket
 
+       uapi.inotifyFd, err = unix.InotifyInit()
+       if err != nil {
+               return nil, err
+       }
+
+       _, err = unix.InotifyAddWatch(
+               uapi.inotifyFd,
+               socketPath,
+               unix.IN_ATTRIB|
+                       unix.IN_DELETE|
+                       unix.IN_DELETE_SELF,
+       )
+
+       if err != nil {
+               return nil, err
+       }
+
        go func(l *UAPIListener) {
-               for ; ; time.Sleep(time.Second) {
-                       if _, err := os.Stat(socketPath); os.IsNotExist(err) {
+               var buff [4096]byte
+               for {
+                       unix.Read(uapi.inotifyFd, buff[:])
+                       if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
                                l.connErr <- err
                                return
                        }