]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Rework of entire locking system
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 2 Feb 2018 15:40:14 +0000 (16:40 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 2 Feb 2018 15:40:14 +0000 (16:40 +0100)
Locking on the Device instance is now much more fined-grained,
seperating out the fields into "resources" st. most common interactions
only require a small number.

src/conn.go
src/device.go
src/noise_helpers.go
src/noise_protocol.go
src/peer.go
src/receive.go
src/send.go
src/timers.go
src/tun_linux.go
src/uapi.go

index c2f5deeb035882db514c72d4a8d2137798620611..fb30ec28df13869f51fe4e8601684e1cac91c86f 100644 (file)
@@ -65,12 +65,12 @@ func unsafeCloseBind(device *Device) error {
 }
 
 func (device *Device) BindUpdate() error {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
 
-       netc := &device.net
-       netc.mutex.Lock()
-       defer netc.mutex.Unlock()
+       device.net.mutex.Lock()
+       defer device.net.mutex.Unlock()
+
+       device.peers.mutex.Lock()
+       defer device.peers.mutex.Unlock()
 
        // close existing sockets
 
@@ -85,6 +85,7 @@ func (device *Device) BindUpdate() error {
                // bind to new port
 
                var err error
+               netc := &device.net
                netc.bind, netc.port, err = CreateBind(netc.port)
                if err != nil {
                        netc.bind = nil
@@ -100,12 +101,12 @@ func (device *Device) BindUpdate() error {
 
                // clear cached source addresses
 
-               for _, peer := range device.peers {
+               for _, peer := range device.peers.keyMap {
                        peer.mutex.Lock()
+                       defer peer.mutex.Unlock()
                        if peer.endpoint != nil {
                                peer.endpoint.ClearSrc()
                        }
-                       peer.mutex.Unlock()
                }
 
                // start receiving routines
@@ -120,10 +121,8 @@ func (device *Device) BindUpdate() error {
 }
 
 func (device *Device) BindClose() error {
-       device.mutex.Lock()
        device.net.mutex.Lock()
        err := unsafeCloseBind(device)
        device.net.mutex.Unlock()
-       device.mutex.Unlock()
        return err
 }
index f1c09c6bf8c694ee641d8be01d1884295973e4b2..0317b604723fb9b223858d72e8575a3f6c4dd15b 100644 (file)
@@ -9,106 +9,170 @@ import (
 )
 
 type Device struct {
-       isUp      AtomicBool // device is (going) up
-       isClosed  AtomicBool // device is closed? (acting as guard)
-       log       *Logger    // collection of loggers for levels
-       idCounter uint       // for assigning debug ids to peers
-       fwMark    uint32
-       tun       struct {
-               device TUNDevice
-               mtu    int32
-       }
+       isUp     AtomicBool // device is (going) up
+       isClosed AtomicBool // device is closed? (acting as guard)
+       log      *Logger
+
+       // synchronized resources (locks acquired in order)
+
        state struct {
                mutex    deadlock.Mutex
                changing AtomicBool
                current  bool
        }
-       pool struct {
-               messageBuffers sync.Pool
-       }
+
        net struct {
                mutex  deadlock.RWMutex
                bind   Bind   // bind interface
                port   uint16 // listening port
                fwmark uint32 // mark value (0 = disabled)
        }
-       mutex        deadlock.RWMutex
-       privateKey   NoisePrivateKey
-       publicKey    NoisePublicKey
-       routingTable RoutingTable
-       indices      IndexTable
-       queue        struct {
+
+       noise struct {
+               mutex      deadlock.RWMutex
+               privateKey NoisePrivateKey
+               publicKey  NoisePublicKey
+       }
+
+       routing struct {
+               mutex deadlock.RWMutex
+               table RoutingTable
+       }
+
+       peers struct {
+               mutex  deadlock.RWMutex
+               keyMap map[NoisePublicKey]*Peer
+       }
+
+       // unprotected / "self-synchronising resources"
+
+       indices IndexTable
+       mac     CookieChecker
+
+       rate struct {
+               underLoadUntil atomic.Value
+               limiter        Ratelimiter
+       }
+
+       pool struct {
+               messageBuffers sync.Pool
+       }
+
+       queue struct {
                encryption chan *QueueOutboundElement
                decryption chan *QueueInboundElement
                handshake  chan QueueHandshakeElement
        }
+
        signal struct {
                stop Signal
        }
-       underLoadUntil atomic.Value
-       ratelimiter    Ratelimiter
-       peers          map[NoisePublicKey]*Peer
-       mac            CookieChecker
+
+       tun struct {
+               device TUNDevice
+               mtu    int32
+       }
 }
 
-func deviceUpdateState(device *Device) {
+/* Converts the peer into a "zombie", which remains in the peer map,
+ * but processes no packets and does not exists in the routing table.
+ *
+ * Must hold:
+ *  device.peers.mutex : exclusive lock
+ *  device.routing     : exclusive lock
+ */
+func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
 
-       // check if state already being updated (guard)
+       // stop routing and processing of packets
 
-       if device.state.changing.Swap(true) {
-               return
+       device.routing.table.RemovePeer(peer)
+       peer.Stop()
+
+       // clean index table
+
+       kp := &peer.keyPairs
+       kp.mutex.Lock()
+
+       if kp.previous != nil {
+               device.indices.Delete(kp.previous.localIndex)
        }
 
-       // compare to current state of device
+       if kp.current != nil {
+               device.indices.Delete(kp.current.localIndex)
+       }
 
-       device.state.mutex.Lock()
+       if kp.next != nil {
+               device.indices.Delete(kp.next.localIndex)
+       }
 
-       newIsUp := device.isUp.Get()
+       kp.previous = nil
+       kp.current = nil
+       kp.next = nil
+       kp.mutex.Unlock()
 
-       if newIsUp == device.state.current {
-               device.state.mutex.Unlock()
-               device.state.changing.Set(false)
+       // remove from peer map
+
+       delete(device.peers.keyMap, key)
+}
+
+func deviceUpdateState(device *Device) {
+
+       // check if state already being updated (guard)
+
+       if device.state.changing.Swap(true) {
                return
        }
 
-       device.state.mutex.Unlock()
+       func() {
 
-       // change state of device
+               // compare to current state of device
 
-       switch newIsUp {
-       case true:
+               device.state.mutex.Lock()
+               defer device.state.mutex.Unlock()
 
-               // start listener
+               newIsUp := device.isUp.Get()
 
-               if err := device.BindUpdate(); err != nil {
-                       device.isUp.Set(false)
-                       break
+               if newIsUp == device.state.current {
+                       device.state.changing.Set(false)
+                       return
                }
 
-               // start every peer
+               // change state of device
 
-               for _, peer := range device.peers {
-                       peer.Start()
-               }
+               switch newIsUp {
+               case true:
+                       if err := device.BindUpdate(); err != nil {
+                               device.isUp.Set(false)
+                               break
+                       }
 
-       case false:
+                       device.peers.mutex.Lock()
+                       defer device.peers.mutex.Unlock()
 
-               // stop listening
+                       for _, peer := range device.peers.keyMap {
+                               peer.Start()
+                       }
 
-               device.BindClose()
+               case false:
+                       device.BindClose()
 
-               // stop every peer
+                       device.peers.mutex.Lock()
+                       defer device.peers.mutex.Unlock()
 
-               for _, peer := range device.peers {
-                       peer.Stop()
+                       for _, peer := range device.peers.keyMap {
+                               println("stopping peer")
+                               peer.Stop()
+                       }
                }
-       }
 
-       // update state variables
-       // and check for state change in the mean time
+               // update state variables
+
+               device.state.current = newIsUp
+               device.state.changing.Set(false)
+       }()
+
+       // check for state change in the mean time
 
-       device.state.current = newIsUp
-       device.state.changing.Set(false)
        deviceUpdateState(device)
 }
 
@@ -133,18 +197,6 @@ func (device *Device) Down() {
        deviceUpdateState(device)
 }
 
-/* Warning:
- * The caller must hold the device mutex (write lock)
- */
-func removePeerUnsafe(device *Device, key NoisePublicKey) {
-       peer, ok := device.peers[key]
-       if !ok {
-               return
-       }
-       device.routingTable.RemovePeer(peer)
-       delete(device.peers, key)
-}
-
 func (device *Device) IsUnderLoad() bool {
 
        // check if currently under load
@@ -152,54 +204,66 @@ func (device *Device) IsUnderLoad() bool {
        now := time.Now()
        underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
        if underLoad {
-               device.underLoadUntil.Store(now.Add(time.Second))
+               device.rate.underLoadUntil.Store(now.Add(time.Second))
                return true
        }
 
        // check if recently under load
 
-       until := device.underLoadUntil.Load().(time.Time)
+       until := device.rate.underLoadUntil.Load().(time.Time)
        return until.After(now)
 }
 
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
+
+       // lock required resources
+
+       device.noise.mutex.Lock()
+       defer device.noise.mutex.Unlock()
+
+       device.routing.mutex.Lock()
+       defer device.routing.mutex.Unlock()
+
+       device.peers.mutex.Lock()
+       defer device.peers.mutex.Unlock()
+
+       for _, peer := range device.peers.keyMap {
+               peer.handshake.mutex.RLock()
+               defer peer.handshake.mutex.RUnlock()
+       }
 
        // remove peers with matching public keys
 
        publicKey := sk.publicKey()
-       for key, peer := range device.peers {
-               h := &peer.handshake
-               h.mutex.RLock()
-               if h.remoteStatic.Equals(publicKey) {
-                       removePeerUnsafe(device, key)
+       for key, peer := range device.peers.keyMap {
+               if peer.handshake.remoteStatic.Equals(publicKey) {
+                       unsafeRemovePeer(device, peer, key)
                }
-               h.mutex.RUnlock()
        }
 
        // update key material
 
-       device.privateKey = sk
-       device.publicKey = publicKey
+       device.noise.privateKey = sk
+       device.noise.publicKey = publicKey
        device.mac.Init(publicKey)
 
-       // do DH pre-computations
+       // do static-static DH pre-computations
+
+       rmKey := device.noise.privateKey.IsZero()
 
-       rmKey := device.privateKey.IsZero()
+       for key, peer := range device.peers.keyMap {
+
+               hs := &peer.handshake
 
-       for key, peer := range device.peers {
-               h := &peer.handshake
-               h.mutex.Lock()
                if rmKey {
-                       h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+                       hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
                } else {
-                       h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
-                       if isZero(h.precomputedStaticStatic[:]) {
-                               removePeerUnsafe(device, key)
-                       }
+                       hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic)
+               }
+
+               if isZero(hs.precomputedStaticStatic[:]) {
+                       unsafeRemovePeer(device, peer, key)
                }
-               h.mutex.Unlock()
        }
 
        return nil
@@ -215,21 +279,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
 
 func NewDevice(tun TUNDevice, logger *Logger) *Device {
        device := new(Device)
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
 
        device.isUp.Set(false)
        device.isClosed.Set(false)
 
        device.log = logger
-       device.peers = make(map[NoisePublicKey]*Peer)
        device.tun.device = tun
+       device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 
-       device.indices.Init()
-       device.ratelimiter.Init()
+       // initialize anti-DoS / anti-scanning features
+
+       device.rate.limiter.Init()
+       device.rate.underLoadUntil.Store(time.Time{})
 
-       device.routingTable.Reset()
-       device.underLoadUntil.Store(time.Time{})
+       // initialize noise & crypt-key routine
+
+       device.indices.Init()
+       device.routing.table.Reset()
 
        // setup buffer pool
 
@@ -264,36 +330,50 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
 
        go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
-       go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
+       go device.rate.limiter.RoutineGarbageCollector(device.signal.stop)
 
        return device
 }
 
 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
-       device.mutex.RLock()
-       defer device.mutex.RUnlock()
-       return device.peers[pk]
+       device.peers.mutex.RLock()
+       defer device.peers.mutex.RUnlock()
+
+       return device.peers.keyMap[pk]
 }
 
 func (device *Device) RemovePeer(key NoisePublicKey) {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
-       removePeerUnsafe(device, key)
+       device.noise.mutex.Lock()
+       defer device.noise.mutex.Unlock()
+
+       device.routing.mutex.Lock()
+       defer device.routing.mutex.Unlock()
+
+       device.peers.mutex.Lock()
+       defer device.peers.mutex.Unlock()
+
+       // stop peer and remove from routing
+
+       peer, ok := device.peers.keyMap[key]
+       if ok {
+               unsafeRemovePeer(device, peer, key)
+       }
 }
 
 func (device *Device) RemoveAllPeers() {
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
 
-       for key, peer := range device.peers {
-               peer.Stop()
-               peer, ok := device.peers[key]
-               if !ok {
-                       return
-               }
-               device.routingTable.RemovePeer(peer)
-               delete(device.peers, key)
+       device.routing.mutex.Lock()
+       defer device.routing.mutex.Unlock()
+
+       device.peers.mutex.Lock()
+       defer device.peers.mutex.Unlock()
+
+       for key, peer := range device.peers.keyMap {
+               println("rm", peer.String())
+               unsafeRemovePeer(device, peer, key)
        }
+
+       device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 }
 
 func (device *Device) Close() {
@@ -305,7 +385,6 @@ func (device *Device) Close() {
        device.tun.device.Close()
        device.BindClose()
        device.isUp.Set(false)
-       println("remove")
        device.RemoveAllPeers()
        device.log.Info.Println("Interface closed")
 }
index 24302c0c7ac04bdc13b620412b2516ff5f838a4f..1e2de5ff2dd98ada2e027d840372357615a262bc 100644 (file)
@@ -3,6 +3,7 @@ package main
 import (
        "crypto/hmac"
        "crypto/rand"
+       "crypto/subtle"
        "golang.org/x/crypto/blake2s"
        "golang.org/x/crypto/curve25519"
        "hash"
@@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
 }
 
 func isZero(val []byte) bool {
-       var acc byte
+       acc := 1
        for _, b := range val {
-               acc |= b
+               acc &= subtle.ConstantTimeByteEq(b, 0)
        }
-       return acc == 0
+       return acc == 1
 }
 
 func setZero(arr []byte) {
index 2f9e1d5ca7bd428a0e35ff23693bff863b7b65c9..d620a0d597367c10c6414c5bcd7d56099d89944c 100644 (file)
@@ -137,6 +137,10 @@ func init() {
 }
 
 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+
+       device.noise.mutex.Lock()
+       defer device.noise.mutex.Unlock()
+
        handshake := &peer.handshake
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
@@ -187,7 +191,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
                        ss[:],
                )
                aead, _ := chacha20poly1305.New(key[:])
-               aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
+               aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])
        }()
        handshake.mixHash(msg.Static[:])
 
@@ -212,16 +216,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 }
 
 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
-       if msg.Type != MessageInitiationType {
-               return nil
-       }
-
        var (
                hash     [blake2s.Size]byte
                chainKey [blake2s.Size]byte
        )
 
-       mixHash(&hash, &InitialHash, device.publicKey[:])
+       if msg.Type != MessageInitiationType {
+               return nil
+       }
+
+       device.noise.mutex.RLock()
+       defer device.noise.mutex.RUnlock()
+
+       mixHash(&hash, &InitialHash, device.noise.publicKey[:])
        mixHash(&hash, &hash, msg.Ephemeral[:])
        mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
 
@@ -231,7 +238,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        var peerPK NoisePublicKey
        func() {
                var key [chacha20poly1305.KeySize]byte
-               ss := device.privateKey.sharedSecret(msg.Ephemeral)
+               ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
                KDF2(&chainKey, &key, chainKey[:], ss[:])
                aead, _ := chacha20poly1305.New(key[:])
                _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@@ -407,7 +414,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                }()
 
                func() {
-                       ss := device.privateKey.sharedSecret(msg.Ephemeral)
+                       ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
                        mixKey(&chainKey, &chainKey, ss[:])
                        setZero(ss[:])
                }()
index 5ad45118f6d0d9db717ceff189edb675612765c2..3b8f7cca631728612199b5e91c3e3e95b1432c67 100644 (file)
@@ -14,7 +14,6 @@ const (
 )
 
 type Peer struct {
-       id                          uint
        isRunning                   AtomicBool
        mutex                       deadlock.RWMutex
        persistentKeepaliveInterval uint64
@@ -22,17 +21,20 @@ type Peer struct {
        handshake                   Handshake
        device                      *Device
        endpoint                    Endpoint
-       stats                       struct {
+
+       stats struct {
                txBytes           uint64 // bytes send to peer (endpoint)
                rxBytes           uint64 // bytes received from peer
                lastHandshakeNano int64  // nano seconds since epoch
        }
+
        time struct {
                mutex         deadlock.RWMutex
                lastSend      time.Time // last send message
                lastHandshake time.Time // last completed handshake
                nextKeepalive time.Time
        }
+
        signal struct {
                newKeyPair         Signal // size 1, new key pair was generated
                handshakeCompleted Signal // size 1, handshake completed
@@ -41,7 +43,9 @@ type Peer struct {
                messageSend        Signal // size 1, message was send to peer
                messageReceived    Signal // size 1, authenticated message recv
        }
+
        timer struct {
+
                // state related to WireGuard timers
 
                keepalivePersistent Timer // set for persistent keepalives
@@ -54,17 +58,20 @@ type Peer struct {
                sendLastMinuteHandshake bool
                needAnotherKeepalive    bool
        }
+
        queue struct {
                nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
                outbound chan *QueueOutboundElement // sequential ordering of work
                inbound  chan *QueueInboundElement  // sequential ordering of work
        }
+
        routines struct {
                mutex    deadlock.Mutex // held when stopping / starting routines
                starting sync.WaitGroup // routines pending start
                stopping sync.WaitGroup // routines pending stop
                stop     Signal         // size 0, stop all goroutines in peer
        }
+
        mac CookieGenerator
 }
 
@@ -74,8 +81,22 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
                return nil, errors.New("Device closed")
        }
 
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
+       // lock resources
+
+       device.state.mutex.Lock()
+       defer device.state.mutex.Unlock()
+
+       device.noise.mutex.RLock()
+       defer device.noise.mutex.RUnlock()
+
+       device.peers.mutex.Lock()
+       defer device.peers.mutex.Unlock()
+
+       // check if over limit
+
+       if len(device.peers.keyMap) >= MaxPeers {
+               return nil, errors.New("Too many peers")
+       }
 
        // create peer
 
@@ -94,32 +115,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        peer.timer.handshakeDeadline = NewTimer()
        peer.timer.handshakeTimeout = NewTimer()
 
-       // assign id for debugging
-
-       peer.id = device.idCounter
-       device.idCounter += 1
-
-       // check if over limit
-
-       if len(device.peers) >= MaxPeers {
-               return nil, errors.New("Too many peers")
-       }
-
        // map public key
 
-       _, ok := device.peers[pk]
+       _, ok := device.peers.keyMap[pk]
        if ok {
                return nil, errors.New("Adding existing peer")
        }
-       device.peers[pk] = peer
+       device.peers.keyMap[pk] = peer
 
        // precompute DH
 
        handshake := &peer.handshake
        handshake.mutex.Lock()
        handshake.remoteStatic = pk
-       handshake.precomputedStaticStatic =
-               device.privateKey.sharedSecret(handshake.remoteStatic)
+       handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk)
        handshake.mutex.Unlock()
 
        // reset endpoint
@@ -134,11 +143,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        // start peer
 
-       peer.device.state.mutex.Lock()
        if peer.device.isUp.Get() {
                peer.Start()
        }
-       peer.device.state.mutex.Unlock()
 
        return peer, nil
 }
@@ -166,14 +173,12 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
 func (peer *Peer) String() string {
        if peer.endpoint == nil {
                return fmt.Sprintf(
-                       "peer(%d unknown %s)",
-                       peer.id,
+                       "peer(unknown %s)",
                        base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
                )
        }
        return fmt.Sprintf(
-               "peer(%d %s %s)",
-               peer.id,
+               "peer(%s %s)",
                peer.endpoint.DstToString(),
                base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
        )
@@ -181,8 +186,12 @@ func (peer *Peer) String() string {
 
 func (peer *Peer) Start() {
 
+       if peer.device.isClosed.Get() {
+               return
+       }
+
        peer.routines.mutex.Lock()
-       defer peer.routines.mutex.Lock()
+       defer peer.routines.mutex.Unlock()
 
        peer.device.log.Debug.Println("Starting:", peer.String())
 
@@ -222,7 +231,7 @@ func (peer *Peer) Start() {
 func (peer *Peer) Stop() {
 
        peer.routines.mutex.Lock()
-       defer peer.routines.mutex.Lock()
+       defer peer.routines.mutex.Unlock()
 
        peer.device.log.Debug.Println("Stopping:", peer.String())
 
index 5ad7c4b86412e79561d1177e6d4839c7da645134..1f44df2221fae5c7f76de86936650688390c5228 100644 (file)
@@ -372,7 +372,7 @@ func (device *Device) RoutineHandshake() {
 
                                // check ratelimiter
 
-                               if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
+                               if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
                                        continue
                                }
                        }
@@ -495,19 +495,23 @@ func (device *Device) RoutineHandshake() {
 
 func (peer *Peer) RoutineSequentialReceiver() {
 
+       defer peer.routines.stopping.Done()
+
        device := peer.device
 
        logInfo := device.log.Info
        logError := device.log.Error
        logDebug := device.log.Debug
-       logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
+       logDebug.Println("Routine, sequential receiver, started for peer", peer.String())
+
+       peer.routines.starting.Done()
 
        for {
 
                select {
 
                case <-peer.routines.stop.Wait():
-                       logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
+                       logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String())
                        return
 
                case elem := <-peer.queue.inbound:
@@ -581,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                // verify IPv4 source
 
                                src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
-                               if device.routingTable.LookupIPv4(src) != peer {
+                               if device.routing.table.LookupIPv4(src) != peer {
                                        logInfo.Println(
                                                "IPv4 packet with disallowed source address from",
                                                peer.String(),
@@ -609,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                // verify IPv6 source
 
                                src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
-                               if device.routingTable.LookupIPv6(src) != peer {
+                               if device.routing.table.LookupIPv6(src) != peer {
                                        logInfo.Println(
                                                "IPv6 packet with disallowed source address from",
                                                peer.String(),
index e0a546d8b46d308fcc837d54865677917e0b2b33..7488d3a280a511380969c91e0f8fadac0245c7f5 100644 (file)
@@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() {
                                continue
                        }
                        dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
-                       peer = device.routingTable.LookupIPv4(dst)
+                       peer = device.routing.table.LookupIPv4(dst)
 
                case ipv6.Version:
                        if len(elem.packet) < ipv6.HeaderLen {
                                continue
                        }
                        dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
-                       peer = device.routingTable.LookupIPv6(dst)
+                       peer = device.routing.table.LookupIPv6(dst)
 
                default:
                        logDebug.Println("Received packet with unknown IP version")
@@ -187,10 +187,14 @@ func (device *Device) RoutineReadFromTUN() {
 func (peer *Peer) RoutineNonce() {
        var keyPair *KeyPair
 
+       defer peer.routines.stopping.Done()
+
        device := peer.device
        logDebug := device.log.Debug
        logDebug.Println("Routine, nonce worker, started for peer", peer.String())
 
+       peer.routines.starting.Done()
+
        for {
        NextPacket:
                select {
index f1ed9c5bd378d08c22d8d72fb1f541fcac0629d4..2ef105e00a1a8d858777ea36464c756a602dd6ad 100644 (file)
@@ -303,7 +303,7 @@ func (peer *Peer) RoutineTimerHandler() {
                        err := peer.sendNewHandshake()\r
                        if err != nil {\r
                                logInfo.Println(\r
-                                       "Failed to send handshake to peer:", peer.String())\r
+                                       "Failed to send handshake to peer:", peer.String(), "(", err, ")")\r
                        }\r
 \r
                case <-peer.timer.handshakeDeadline.Wait():\r
@@ -326,7 +326,7 @@ func (peer *Peer) RoutineTimerHandler() {
                        err := peer.sendNewHandshake()\r
                        if err != nil {\r
                                logInfo.Println(\r
-                                       "Failed to send handshake to peer:", peer.String())\r
+                                       "Failed to send handshake to peer:", peer.String(), "(", err, ")")\r
                        }\r
 \r
                        peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)\r
index daa2462ad7650c27825cf4eec3074d0f57597f25..975616952f3191ebdf17762049daf2858ae578ae 100644 (file)
@@ -313,7 +313,7 @@ func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
        }
 
        go device.RoutineNetlinkListener()
-       go device.RoutineHackListener() // cross namespace
+       // go device.RoutineHackListener() // cross namespace
 
        // set default MTU
 
@@ -369,7 +369,7 @@ func CreateTUN(name string) (TUNDevice, error) {
        }
 
        go device.RoutineNetlinkListener()
-       go device.RoutineHackListener() // cross namespace
+       // go device.RoutineHackListener() // cross namespace
 
        // set default MTU
 
index 68ebe43111c46cd6763c9ee579e90f6bd09e3580..caaa49837fdd043a2e10bc759cdd47e1a70df7d3 100644 (file)
@@ -25,32 +25,51 @@ func (s *IPCError) ErrorCode() int64 {
 
 func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
-       // create lines
+       device.log.Debug.Println("UAPI: Processing get operation")
 
-       device.mutex.RLock()
-       device.net.mutex.RLock()
+       // create lines
 
        lines := make([]string, 0, 100)
        send := func(line string) {
                lines = append(lines, line)
        }
 
-       if !device.privateKey.IsZero() {
-               send("private_key=" + device.privateKey.ToHex())
-       }
+       func() {
 
-       if device.net.port != 0 {
-               send(fmt.Sprintf("listen_port=%d", device.net.port))
-       }
+               // lock required resources
 
-       if device.net.fwmark != 0 {
-               send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
-       }
+               device.net.mutex.RLock()
+               defer device.net.mutex.RUnlock()
+
+               device.noise.mutex.RLock()
+               defer device.noise.mutex.RUnlock()
+
+               device.routing.mutex.RLock()
+               defer device.routing.mutex.RUnlock()
+
+               device.peers.mutex.Lock()
+               defer device.peers.mutex.Unlock()
+
+               // serialize device related values
+
+               if !device.noise.privateKey.IsZero() {
+                       send("private_key=" + device.noise.privateKey.ToHex())
+               }
+
+               if device.net.port != 0 {
+                       send(fmt.Sprintf("listen_port=%d", device.net.port))
+               }
+
+               if device.net.fwmark != 0 {
+                       send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+               }
 
-       for _, peer := range device.peers {
-               func() {
+               // serialize each peer state
+
+               for _, peer := range device.peers.keyMap {
                        peer.mutex.RLock()
                        defer peer.mutex.RUnlock()
+
                        send("public_key=" + peer.handshake.remoteStatic.ToHex())
                        send("preshared_key=" + peer.handshake.presharedKey.ToHex())
                        if peer.endpoint != nil {
@@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                atomic.LoadUint64(&peer.persistentKeepaliveInterval),
                        ))
 
-                       for _, ip := range device.routingTable.AllowedIPs(peer) {
+                       for _, ip := range device.routing.table.AllowedIPs(peer) {
                                send("allowed_ip=" + ip.String())
                        }
-               }()
-       }
 
-       device.net.mutex.RUnlock()
-       device.mutex.RUnlock()
+               }
+       }()
 
-       // send lines
+       // send lines (does not require resource locks)
 
        for _, line := range lines {
                _, err := socket.WriteString(line + "\n")
@@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
 func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        scanner := bufio.NewScanner(socket)
-       logInfo := device.log.Info
        logError := device.log.Error
        logDebug := device.log.Debug
 
@@ -130,6 +146,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logError.Println("Failed to set private_key:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
+                               logDebug.Println("UAPI: Updating device private key")
                                device.SetPrivateKey(sk)
 
                        case "listen_port":
@@ -144,6 +161,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                // update port and rebind
 
+                               logDebug.Println("UAPI: Updating listen port")
+
                                device.net.mutex.Lock()
                                device.net.port = uint16(port)
                                device.net.mutex.Unlock()
@@ -170,6 +189,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
 
+                               logDebug.Println("UAPI: Updating fwmark")
+
                                device.net.mutex.Lock()
                                device.net.fwmark = uint32(fwmark)
                                device.net.mutex.Unlock()
@@ -181,6 +202,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                        case "public_key":
                                // switch to peer configuration
+                               logDebug.Println("UAPI: Transition to peer configuration")
                                deviceConfig = false
 
                        case "replace_peers":
@@ -188,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logError.Println("Failed to set replace_peers, invalid value:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
+                               logDebug.Println("UAPI: Removing all peers")
                                device.RemoveAllPeers()
 
                        default:
@@ -203,43 +226,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        switch key {
 
                        case "public_key":
-                               var pubKey NoisePublicKey
-                               err := pubKey.FromHex(value)
+                               var publicKey NoisePublicKey
+                               err := publicKey.FromHex(value)
                                if err != nil {
                                        logError.Println("Failed to get peer by public_key:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
 
-                               // check if public key of peer equal to device
+                               // ignore peer with public key of device
 
-                               device.mutex.RLock()
-                               if device.publicKey.Equals(pubKey) {
-
-                                       // create dummy instance (not added to device)
+                               device.noise.mutex.RLock()
+                               equals := device.noise.publicKey.Equals(publicKey)
+                               device.noise.mutex.RUnlock()
 
+                               if equals {
                                        peer = &Peer{}
                                        dummy = true
-                                       device.mutex.RUnlock()
-                                       logInfo.Println("Ignoring peer with public key of device")
+                               }
 
-                               } else {
+                               // find peer referenced
 
-                                       // find peer referenced
+                               peer = device.LookupPeer(publicKey)
 
-                                       peer, _ = device.peers[pubKey]
-                                       device.mutex.RUnlock()
-                                       if peer == nil {
-                                               peer, err = device.NewPeer(pubKey)
-                                               if err != nil {
-                                                       logError.Println("Failed to create new peer:", err)
-                                                       return &IPCError{Code: ipcErrorInvalid}
-                                               }
+                               if peer == nil {
+                                       peer, err = device.NewPeer(publicKey)
+                                       if err != nil {
+                                               logError.Println("Failed to create new peer:", err)
+                                               return &IPCError{Code: ipcErrorInvalid}
                                        }
-                                       peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
-                                       dummy = false
-
+                                       logDebug.Println("UAPI: Created new peer:", peer.String())
                                }
 
+                               peer.mutex.Lock()
+                               peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
+                               peer.mutex.Unlock()
+
                        case "remove":
 
                                // remove currently selected peer from device
@@ -249,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
                                if !dummy {
-                                       logDebug.Println("Removing", peer.String())
+                                       logDebug.Println("UAPI: Removing peer:", peer.String())
                                        device.RemovePeer(peer.handshake.remoteStatic)
                                }
                                peer = &Peer{}
@@ -259,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                // update PSK
 
-                               peer.mutex.Lock()
+                               logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String())
+
+                               peer.handshake.mutex.Lock()
                                err := peer.handshake.presharedKey.FromHex(value)
-                               peer.mutex.Unlock()
+                               peer.handshake.mutex.Unlock()
+
                                if err != nil {
                                        logError.Println("Failed to set preshared_key:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
@@ -271,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                // set endpoint destination
 
+                               logDebug.Println("UAPI: Updating endpoint for peer:", peer.String())
+
                                err := func() error {
                                        peer.mutex.Lock()
                                        defer peer.mutex.Unlock()
@@ -292,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                                // update keep-alive interval
 
+                               logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String())
+
                                secs, err := strconv.ParseUint(value, 10, 16)
                                if err != nil {
                                        logError.Println("Failed to set persistent_keepalive_interval:", err)
@@ -316,25 +344,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }
 
                        case "replace_allowed_ips":
+
+                               logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String())
+
                                if value != "true" {
                                        logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-                               if !dummy {
-                                       device.routingTable.RemovePeer(peer)
+
+                               if dummy {
+                                       continue
                                }
 
+                               device.routing.mutex.Lock()
+                               device.routing.table.RemovePeer(peer)
+                               device.routing.mutex.Unlock()
+
                        case "allowed_ip":
+
+                               logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String())
+
                                _, network, err := net.ParseCIDR(value)
                                if err != nil {
                                        logError.Println("Failed to set allowed_ip:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-                               ones, _ := network.Mask.Size()
-                               if !dummy {
-                                       device.routingTable.Insert(network.IP, uint(ones), peer)
+
+                               if dummy {
+                                       continue
                                }
 
+                               ones, _ := network.Mask.Size()
+                               device.routing.mutex.Lock()
+                               device.routing.table.Insert(network.IP, uint(ones), peer)
+                               device.routing.mutex.Unlock()
+
                        default:
                                logError.Println("Invalid UAPI key (peer configuration):", key)
                                return &IPCError{Code: ipcErrorInvalid}