]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Terminate on interface deletion
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 13 Jul 2017 12:32:40 +0000 (14:32 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 13 Jul 2017 12:32:40 +0000 (14:32 +0200)
Program now terminates when the interface is removed
Increases the number of os threads (relevant for Go <1.5, not tested)
More consistent commenting
Improved logging (additional peer information)

src/constants.go
src/device.go
src/ip.go
src/main.go
src/peer.go
src/receive.go
src/send.go
src/timers.go
src/trie.go

index 03847418e845fd864992be5894c70bbf66dc8cbb..6b0d41459a2363773008612b6be855cbcbbdb39d 100644 (file)
@@ -29,6 +29,6 @@ const (
        QueueInboundSize       = 1024
        QueueHandshakeSize     = 1024
        QueueHandshakeBusySize = QueueHandshakeSize / 8
-       MinMessageSize         = MessageTransportSize // keep-alive
-       MaxMessageSize         = 4096                 // TODO: make depend on the MTU?
+       MinMessageSize         = MessageTransportSize // size of keep-alive
+       MaxMessageSize         = (1 << 16) - 1
 )
index a26cc7be12ab6840c8d3f3eb3a0be192ce5d77a9..b2725445e0a1dc7c494fc62415055e0422730b65 100644 (file)
@@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        }
 
        go device.RoutineBusyMonitor()
+       go device.RoutineWriteToTUN(tun)
        go device.RoutineReadFromTUN(tun)
        go device.RoutineReceiveIncomming()
-       go device.RoutineWriteToTUN(tun)
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 
        return device
@@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() {
 func (device *Device) Close() {
        device.RemoveAllPeers()
        close(device.signal.stop)
-       close(device.queue.encryption)
+}
+
+func (device *Device) Wait() {
+       <-device.signal.stop
 }
index 36beb9c8131901c520df7da27ea01c59a78f39cb..752a404af30a98ef319f15022f833cc6677b3c7f 100644 (file)
--- a/src/ip.go
+++ b/src/ip.go
@@ -5,17 +5,13 @@ import (
 )
 
 const (
-       IPv4version           = 4
        IPv4offsetTotalLength = 2
        IPv4offsetSrc         = 12
        IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
-       IPv4headerSize        = 20
 )
 
 const (
-       IPv6version             = 6
        IPv6offsetPayloadLength = 4
        IPv6offsetSrc           = 8
        IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
-       IPv6headerSize          = 40
 )
index 50140e342ebdc6b60ffeab1c1e6174319d0b5a7c..dc27472266e56515c458bc65e54a8e1f8f0f75a9 100644 (file)
@@ -5,6 +5,7 @@ import (
        "log"
        "net"
        "os"
+       "runtime"
 )
 
 /* TODO: Fix logging
@@ -18,6 +19,10 @@ func main() {
        }
        deviceName := os.Args[1]
 
+       // increase number of go workers (for Go <1.5)
+
+       runtime.GOMAXPROCS(runtime.NumCPU())
+
        // open TUN device
 
        tun, err := CreateTUN(deviceName)
@@ -31,17 +36,21 @@ func main() {
 
        // start configuration lister
 
-       socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
-       l, err := net.Listen("unix", socketPath)
-       if err != nil {
-               log.Fatal("listen error:", err)
-       }
-
-       for {
-               conn, err := l.Accept()
+       go func() {
+               socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
+               l, err := net.Listen("unix", socketPath)
                if err != nil {
-                       log.Fatal("accept error:", err)
+                       log.Fatal("listen error:", err)
                }
-               go ipcHandle(device, conn)
-       }
+
+               for {
+                       conn, err := l.Accept()
+                       if err != nil {
+                               log.Fatal("accept error:", err)
+                       }
+                       go ipcHandle(device, conn)
+               }
+       }()
+
+       device.Wait()
 }
index c8dc5c0d7e83bd9b499a7e96906a7f6c0cc2e0de..408c605ab108ba711b235221d4b8d9961aaac166 100644 (file)
@@ -1,7 +1,9 @@
 package main
 
 import (
+       "encoding/base64"
        "errors"
+       "fmt"
        "net"
        "sync"
        "time"
@@ -38,9 +40,9 @@ type Peer struct {
                /* Both keep-alive timers acts as one (see timers.go)
                 * They are kept seperate to simplify the implementation.
                 */
-               keepalivePersistent      *time.Timer // set for persistent keepalives
-               keepaliveAcknowledgement *time.Timer // set upon recieving messages
-               zeroAllKeys              *time.Timer // zero all key material after RejectAfterTime*3
+               keepalivePersistent *time.Timer // set for persistent keepalives
+               keepalivePassive    *time.Timer // set upon recieving messages
+               zeroAllKeys         *time.Timer // zero all key material after RejectAfterTime*3
        }
        queue struct {
                nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
@@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        peer.mac.Init(pk)
        peer.device = device
 
+       peer.timer.keepalivePassive = NewStoppedTimer()
        peer.timer.keepalivePersistent = NewStoppedTimer()
-       peer.timer.keepaliveAcknowledgement = NewStoppedTimer()
        peer.timer.zeroAllKeys = NewStoppedTimer()
 
        peer.flags.keepaliveWaiting = AtomicFalse
@@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        return peer
 }
 
+func (peer *Peer) String() string {
+       return fmt.Sprintf(
+               "peer(%d %s %s)",
+               peer.id,
+               peer.endpoint.String(),
+               base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
+       )
+}
+
 func (peer *Peer) Close() {
        close(peer.signal.stop)
 }
index 99089a9b6805f85013fae7f88fc0de844ea03263..3e649b6c0dc861444641eb5d2b6d3bb0caeab14b 100644 (file)
@@ -4,6 +4,8 @@ import (
        "bytes"
        "encoding/binary"
        "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/net/ipv4"
+       "golang.org/x/net/ipv6"
        "net"
        "sync"
        "sync/atomic"
@@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() {
                                        return
                                }
 
-                               logDebug.Println("Creating response...")
+                               logDebug.Println("Creating response message for", peer.String())
 
                                outElem := device.NewOutboundElement()
                                writer := bytes.NewBuffer(outElem.data[:0])
@@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
        var elem *QueueInboundElement
 
        device := peer.device
+
+       logInfo := device.log.Info
        logDebug := device.log.Debug
        logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
 
@@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                        peer.KeepKeyFreshReceiving()
 
-                       // check if confirming handshake
+                       // check if using new key-pair
 
                        kp := &peer.keyPairs
                        kp.mutex.Lock()
@@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        // check for keep-alive
 
                        if len(elem.packet) == 0 {
+                               logDebug.Println("Received keep-alive from", peer.String())
                                return
                        }
 
                        // verify source and strip padding
 
                        switch elem.packet[0] >> 4 {
-                       case IPv4version:
+                       case ipv4.Version:
 
                                // strip padding
 
-                               if len(elem.packet) < IPv4headerSize {
+                               if len(elem.packet) < ipv4.HeaderLen {
                                        return
                                }
 
@@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
                                dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
                                if device.routingTable.LookupIPv4(dst) != peer {
+                                       logInfo.Println("Packet with unallowed source IP from", peer.String())
                                        return
                                }
 
-                       case IPv6version:
+                       case ipv6.Version:
 
                                // strip padding
 
-                               if len(elem.packet) < IPv6headerSize {
+                               if len(elem.packet) < ipv6.HeaderLen {
                                        return
                                }
 
                                field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
                                length := binary.BigEndian.Uint16(field)
-                               length += IPv6headerSize
+                               length += ipv6.HeaderLen
                                elem.packet = elem.packet[:length]
 
                                // verify IPv6 source
 
                                dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
                                if device.routingTable.LookupIPv6(dst) != peer {
+                                       logInfo.Println("Packet with unallowed source IP from", peer.String())
                                        return
                                }
 
                        default:
-                               logDebug.Println("Receieved packet with unknown IP version")
+                               logInfo.Println("Packet with invalid IP version from", peer.String())
                                return
                        }
 
@@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 }
 
 func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
+
        logError := device.log.Error
        logDebug := device.log.Debug
        logDebug.Println("Routine, sequential tun writer, started")
index 5ea9a8f8920a853a68fde7693473d90c3c4ef2e2..d8ddc823b19b19acb094f1a70e219559159cb87f 100644 (file)
@@ -3,6 +3,8 @@ package main
 import (
        "encoding/binary"
        "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/net/ipv4"
+       "golang.org/x/net/ipv6"
        "net"
        "sync"
        "sync/atomic"
@@ -21,28 +23,26 @@ import (
  * The functions in this file occure (roughly) in the order packets are processed.
  */
 
-/* A work unit
- *
- * The sequential consumers will attempt to take the lock,
- * workers release lock when they have completed work on the packet.
+/* The sequential consumers will attempt to take the lock,
+ * workers release lock when they have completed work (encryption) on the packet.
  *
  * If the element is inserted into the "encryption queue",
- * the content is preceeded by enough "junk" to contain the header
+ * the content is preceeded by enough "junk" to contain the transport header
  * (to allow the construction of transport messages in-place)
  */
 type QueueOutboundElement struct {
        dropped int32
        mutex   sync.Mutex
-       data    [MaxMessageSize]byte
-       packet  []byte   // slice of "data" (always!)
-       nonce   uint64   // nonce for encryption
-       keyPair *KeyPair // key-pair for encryption
-       peer    *Peer    // related peer
+       data    [MaxMessageSize]byte // slice holding the packet data
+       packet  []byte               // slice of "data" (always!)
+       nonce   uint64               // nonce for encryption
+       keyPair *KeyPair             // key-pair for encryption
+       peer    *Peer                // related peer
 }
 
 func (peer *Peer) FlushNonceQueue() {
        elems := len(peer.queue.nonce)
-       for i := 0; i < elems; i += 1 {
+       for i := 0; i < elems; i++ {
                select {
                case <-peer.queue.nonce:
                default:
@@ -111,14 +111,18 @@ func addToEncryptionQueue(
  * Obs. Single instance per TUN device
  */
 func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+
        if tun == nil {
-               // dummy
                return
        }
 
        elem := device.NewOutboundElement()
 
-       device.log.Debug.Println("Routine, TUN Reader: started")
+       logDebug := device.log.Debug
+       logError := device.log.Error
+
+       logDebug.Println("Routine, TUN Reader: started")
+
        for {
                // read packet
 
@@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
                elem.packet = elem.data[MessageTransportHeaderSize:]
                size, err := tun.Read(elem.packet)
                if err != nil {
-                       device.log.Error.Println("Failed to read packet from TUN device:", err)
-                       continue
+
+                       // stop process
+
+                       logError.Println("Failed to read packet from TUN device:", err)
+                       device.Close()
+                       return
                }
+
                elem.packet = elem.packet[:size]
-               if len(elem.packet) < IPv4headerSize {
-                       device.log.Error.Println("Packet too short, length:", size)
+               if len(elem.packet) < ipv4.HeaderLen {
+                       logError.Println("Packet too short, length:", size)
                        continue
                }
 
@@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 
                var peer *Peer
                switch elem.packet[0] >> 4 {
-               case IPv4version:
+               case ipv4.Version:
                        dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
                        peer = device.routingTable.LookupIPv4(dst)
 
-               case IPv6version:
+               case ipv6.Version:
                        dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
                        peer = device.routingTable.LookupIPv6(dst)
 
                default:
-                       device.log.Debug.Println("Receieved packet with unknown IP version")
+                       logDebug.Println("Receieved packet with unknown IP version")
                }
 
                if peer == nil {
                        continue
                }
+
                if peer.endpoint == nil {
-                       device.log.Debug.Println("No known endpoint for peer", peer.id)
+                       logDebug.Println("No known endpoint for peer", peer.String())
                        continue
                }
 
@@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() {
 
        device := peer.device
        logDebug := device.log.Debug
-       logDebug.Println("Routine, nonce worker, started for peer", peer.id)
+       logDebug.Println("Routine, nonce worker, started for peer", peer.String())
 
        func() {
 
@@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() {
                                        }
                                }
                                signalSend(peer.signal.handshakeBegin)
-                               logDebug.Println("Waiting for key-pair, peer", peer.id)
+                               logDebug.Println("Awaiting key-pair for", peer.String())
 
                                select {
                                case <-peer.signal.newKeyPair:
-                                       logDebug.Println("Key-pair negotiated for peer", peer.id)
+                                       logDebug.Println("Key-pair negotiated for", peer.String())
                                        goto NextPacket
 
                                case <-peer.signal.flushNonceQueue:
-                                       logDebug.Println("Clearing queue for peer", peer.id)
+                                       logDebug.Println("Clearing queue for", peer.String())
                                        peer.FlushNonceQueue()
                                        elem = nil
                                        goto NextPacket
@@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() {
        device := peer.device
 
        logDebug := device.log.Debug
-       logDebug.Println("Routine, sequential sender, started for peer", peer.id)
+       logDebug.Println("Routine, sequential sender, started for", peer.String())
 
        for {
                select {
                case <-peer.signal.stop:
-                       logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
+                       logDebug.Println("Routine, sequential sender, stopped for", peer.String())
                        return
+
                case work := <-peer.queue.outbound:
                        work.mutex.Lock()
                        if work.IsDropped() {
@@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() {
                                defer peer.mutex.RUnlock()
 
                                if peer.endpoint == nil {
-                                       logDebug.Println("No endpoint for peer:", peer.id)
+                                       logDebug.Println("No endpoint for", peer.String())
                                        return
                                }
 
@@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() {
                                }
                                atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
 
-                               // reset keep-alive (passive keep-alives / acknowledgements)
+                               // reset keep-alive
 
                                peer.TimerResetKeepalive()
                        }()
index 63939556c8deceb7ed32533c9914dd953cccfd81..2e5046e0f4f9e0fd0a4a0e9b841bfe0a5abbd807 100644 (file)
@@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
  * - First transport message under the "next" key
  */
 func (peer *Peer) EventHandshakeComplete() {
-       peer.device.log.Debug.Println("Handshake completed")
+       peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
        peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
        signalSend(peer.signal.handshakeCompleted)
 }
@@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() {
 
        // stop acknowledgement timer
 
-       timerStop(peer.timer.keepaliveAcknowledgement)
+       timerStop(peer.timer.keepalivePassive)
        atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
 }
 
@@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() {
        device := peer.device
 
        logDebug := device.log.Debug
-       logDebug.Println("Routine, timer handler, started for peer", peer.id)
+       logDebug.Println("Routine, timer handler, started for peer", peer.String())
 
        for {
                select {
@@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() {
 
                case <-peer.timer.keepalivePersistent.C:
 
-                       logDebug.Println("Sending persistent keep-alive to peer", peer.id)
+                       logDebug.Println("Sending persistent keep-alive to", peer.String())
 
                        peer.SendKeepAlive()
                        peer.TimerResetKeepalive()
 
-               case <-peer.timer.keepaliveAcknowledgement.C:
+               case <-peer.timer.keepalivePassive.C:
 
-                       logDebug.Println("Sending passive persistent keep-alive to peer", peer.id)
+                       logDebug.Println("Sending passive persistent keep-alive to", peer.String())
 
                        peer.SendKeepAlive()
                        peer.TimerResetKeepalive()
@@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() {
 
                case <-peer.timer.zeroAllKeys.C:
 
-                       logDebug.Println("Clearing all key material for peer", peer.id)
+                       logDebug.Println("Clearing all key material for", peer.String())
 
                        // zero out key pairs
 
@@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 
        var elem *QueueOutboundElement
 
+       logInfo := device.log.Info
        logError := device.log.Error
        logDebug := device.log.Debug
-       logDebug.Println("Routine, handshake initator, started for peer", peer.id)
+       logDebug.Println("Routine, handshake initator, started for", peer.String())
 
-       for run := true; run; {
-               var err error
-               var attempts uint
-               var deadline time.Time
+       for {
 
                // wait for signal
 
@@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 
                // wait for handshake
 
-               run = func() bool {
-                       for {
+               func() {
+                       var err error
+                       var deadline time.Time
+                       for attempts := uint(1); ; attempts++ {
 
                                // clear completed signal
 
                                select {
                                case <-peer.signal.handshakeCompleted:
                                case <-peer.signal.stop:
-                                       return false
+                                       return
                                default:
                                }
 
@@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                                }
                                elem, err = peer.BeginHandshakeInitiation()
                                if err != nil {
-                                       logError.Println("Failed to create initiation message:", err)
-                                       break
+                                       logError.Println("Failed to create initiation message", err, "for", peer.String())
+                                       return
                                }
 
                                // set timeout
 
-                               attempts += 1
                                if attempts == 1 {
                                        deadline = time.Now().Add(MaxHandshakeAttemptTime)
                                }
                                timeout := time.NewTimer(RekeyTimeout)
-                               logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
+                               logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
 
                                // wait for handshake or timeout
 
                                select {
+
                                case <-peer.signal.stop:
-                                       return true
+                                       return
 
                                case <-peer.signal.handshakeCompleted:
                                        <-timeout.C
-                                       return true
+                                       return
 
                                case <-timeout.C:
-                                       logDebug.Println("Timeout")
-
-                                       // check if sufficient time for retry
-
                                        if deadline.Before(time.Now().Add(RekeyTimeout)) {
+                                               logInfo.Println("Handshake negotiation timed out for", peer.String())
                                                signalSend(peer.signal.flushNonceQueue)
                                                timerStop(peer.timer.keepalivePersistent)
-                                               timerStop(peer.timer.keepaliveAcknowledgement)
-                                               return true
+                                               timerStop(peer.timer.keepalivePassive)
+                                               return
                                        }
                                }
                        }
-                       return true
                }()
 
                signalClear(peer.signal.handshakeBegin)
index c2304b2f3d0087cecaebed7c295dfa81fecfa1da..e81b5b62c95409b2bb0c9071dff0f38b32e8d59e 100644 (file)
@@ -23,7 +23,8 @@ type Trie struct {
        bits  []byte
        peer  *Peer
 
-       // Index of "branching" bit
+       // index of "branching" bit
+
        bit_at_byte  uint
        bit_at_shift uint
 }
@@ -36,7 +37,7 @@ type Trie struct {
 func commonBits(ip1 net.IP, ip2 net.IP) uint {
        var i uint
        size := uint(len(ip1))
-       for i = 0; i < size; i += 1 {
+       for i = 0; i < size; i++ {
                v := ip1[i] ^ ip2[i]
                if v != 0 {
                        v >>= 1
@@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
                return node
        }
 
-       // Walk recursivly
+       // walk recursivly
 
        node.child[0] = node.child[0].RemovePeer(p)
        node.child[1] = node.child[1].RemovePeer(p)
@@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
                return node
        }
 
-       // Remove peer & merge
+       // remove peer & merge
 
        node.peer = nil
        if node.child[0] == nil {
@@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte {
 
 func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 
-       // At leaf
+       // at leaf
 
        if node == nil {
                return &Trie{
@@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
                }
        }
 
-       // Traverse deeper
+       // traverse deeper
 
        common := commonBits(node.bits, ip)
        if node.cidr <= cidr && common >= node.cidr {
@@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
                return node
        }
 
-       // Split node
+       // split node
 
        newNode := &Trie{
                bits:         ip,
@@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 
        cidr = min(cidr, common)
 
-       // Check for shorter prefix
+       // check for shorter prefix
 
        if newNode.cidr == cidr {
                bit := newNode.choose(node.bits)
@@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
                return newNode
        }
 
-       // Create new parent for node & newNode
+       // create new parent for node & newNode
 
        parent := &Trie{
                bits:         ip,