]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
More consistent use of signal struct
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 1 Dec 2017 22:37:26 +0000 (23:37 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 1 Dec 2017 22:37:26 +0000 (23:37 +0100)
src/device.go
src/main.go
src/misc.go
src/ratelimiter.go
src/receive.go
src/send.go
src/signal.go
src/timers.go
src/trie.go

index a1ce802589a5b1fdc3c65a3725c792746ece41fd..a3461adca1b6a74293fffc610bf91440c8e97bac 100644 (file)
@@ -37,7 +37,7 @@ type Device struct {
                handshake  chan QueueHandshakeElement
        }
        signal struct {
-               stop chan struct{}
+               stop Signal
        }
        underLoadUntil atomic.Value
        ratelimiter    Ratelimiter
@@ -129,7 +129,6 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
 
 func NewDevice(tun TUNDevice, logger *Logger) *Device {
        device := new(Device)
-
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
@@ -160,7 +159,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
 
        // prepare signals
 
-       device.signal.stop = make(chan struct{})
+       device.signal.stop = NewSignal()
 
        // prepare net
 
@@ -174,9 +173,11 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
                go device.RoutineDecryption()
                go device.RoutineHandshake()
        }
+
        go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
+
        return device
 }
 
@@ -210,11 +211,11 @@ func (device *Device) Close() {
        }
        device.log.Info.Println("Closing device")
        device.RemoveAllPeers()
-       close(device.signal.stop)
-       closeBind(device)
+       device.signal.stop.Broadcast()
        device.tun.device.Close()
+       closeBind(device)
 }
 
-func (device *Device) WaitChannel() chan struct{} {
-       return device.signal.stop
+func (device *Device) Wait() chan struct{} {
+       return device.signal.stop.Wait()
 }
index e43176c863e551d5555ee68bbb1e5774e50de360..8bca78c9b38b1d6d6c17854542772a44d2a43f04 100644 (file)
@@ -8,6 +8,10 @@ import (
        "strconv"
 )
 
+import _ "net/http/pprof"
+import "net/http"
+import "log"
+
 const (
        ExitSetupSuccess = 0
        ExitSetupFailed  = 1
@@ -25,6 +29,10 @@ func printUsage() {
 
 func main() {
 
+       go func() {
+               log.Println(http.ListenAndServe("localhost:6060", nil))
+       }()
+
        // parse arguments
 
        var foreground bool
@@ -160,7 +168,6 @@ func main() {
 
        errs := make(chan error)
        term := make(chan os.Signal)
-       wait := device.WaitChannel()
 
        uapi, err := UAPIListen(interfaceName, fileUAPI)
 
@@ -183,9 +190,9 @@ func main() {
        signal.Notify(term, os.Interrupt)
 
        select {
-       case <-wait:
        case <-term:
        case <-errs:
+       case <-device.Wait():
        }
 
        // clean up
index b43e97ec20ac7c778adcd26aa56bf3be300acdd2..80e33f627626119e88c934f7265eae43e4377eb8 100644 (file)
@@ -2,12 +2,10 @@ package main
 
 import (
        "sync/atomic"
-       "time"
 )
 
-/* We use int32 as atomic bools
- * (since booleans are not natively supported by sync/atomic)
- */
+/* Atomic Boolean */
+
 const (
        AtomicFalse = int32(iota)
        AtomicTrue
@@ -37,6 +35,8 @@ func (a *AtomicBool) Set(val bool) {
        atomic.StoreInt32(&a.flag, flag)
 }
 
+/* Integer manipulation */
+
 func toInt32(n uint32) int32 {
        mask := uint32(1 << 31)
        return int32(-(n & mask) + (n & ^mask))
@@ -55,32 +55,3 @@ func minUint64(a uint64, b uint64) uint64 {
        }
        return a
 }
-
-func signalSend(c chan struct{}) {
-       select {
-       case c <- struct{}{}:
-       default:
-       }
-}
-
-func signalClear(c chan struct{}) {
-       select {
-       case <-c:
-       default:
-       }
-}
-
-func timerStop(timer *time.Timer) {
-       if !timer.Stop() {
-               select {
-               case <-timer.C:
-               default:
-               }
-       }
-}
-
-func NewStoppedTimer() *time.Timer {
-       timer := time.NewTimer(time.Hour)
-       timerStop(timer)
-       return timer
-}
index 4f8227ecb3af9f4b635be1ffafb8308474289f3f..6e5f005fa52fd6882c83fdc7486fa633d047acc9 100644 (file)
@@ -66,11 +66,11 @@ func (rate *Ratelimiter) GarbageCollectEntries() {
        rate.mutex.Unlock()
 }
 
-func (rate *Ratelimiter) RoutineGarbageCollector(stop chan struct{}) {
+func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) {
        timer := time.NewTimer(time.Second)
        for {
                select {
-               case <-stop:
+               case <-stop.Wait():
                        return
                case <-timer.C:
                        rate.GarbageCollectEntries()
index fd1993eab9afd0b05206289fcebc163207fb79f4..f650cc9d73259da3b1ddbcf60d78bf27eba0aed9 100644 (file)
@@ -93,6 +93,11 @@ func (device *Device) addToHandshakeQueue(
        }
 }
 
+/* Receives incoming datagrams for the device
+ *
+ * Every time the bind is updated a new routine is started for
+ * IPv4 and IPv6 (separately)
+ */
 func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
 
        logDebug := device.log.Debug
@@ -182,6 +187,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
                        device.addToDecryptionQueue(device.queue.decryption, elem)
                        device.addToInboundQueue(peer.queue.inbound, elem)
                        buffer = device.GetMessageBuffer()
+
                        continue
 
                // otherwise it is a fixed size & handshake related packet
@@ -220,7 +226,7 @@ func (device *Device) RoutineDecryption() {
 
        for {
                select {
-               case <-device.signal.stop:
+               case <-device.signal.stop.Wait():
                        logDebug.Println("Routine, decryption worker, stopped")
                        return
 
@@ -256,7 +262,7 @@ func (device *Device) RoutineDecryption() {
        }
 }
 
-/* Handles incomming packets related to handshake
+/* Handles incoming packets related to handshake
  */
 func (device *Device) RoutineHandshake() {
 
@@ -271,7 +277,7 @@ func (device *Device) RoutineHandshake() {
        for {
                select {
                case elem = <-device.queue.handshake:
-               case <-device.signal.stop:
+               case <-device.signal.stop.Wait():
                        return
                }
 
@@ -356,7 +362,7 @@ func (device *Device) RoutineHandshake() {
                        continue
                }
 
-               // handle handshake initation/response content
+               // handle handshake initiation/response content
 
                switch elem.msgType {
                case MessageInitiationType:
@@ -376,7 +382,7 @@ func (device *Device) RoutineHandshake() {
                        peer := device.ConsumeMessageInitiation(&msg)
                        if peer == nil {
                                logInfo.Println(
-                                       "Recieved invalid initiation message from",
+                                       "Received invalid initiation message from",
                                        elem.endpoint.DstToString(),
                                )
                                continue
@@ -449,7 +455,7 @@ func (device *Device) RoutineHandshake() {
                        peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
-                       logDebug.Println("Received handshake initation from", peer)
+                       logDebug.Println("Received handshake initiation from", peer)
 
                        peer.TimerEphemeralKeyCreated()
 
@@ -556,7 +562,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
                                if device.routingTable.LookupIPv4(src) != peer {
                                        logInfo.Println(
-                                               "IPv4 packet with unallowed source address from",
+                                               "IPv4 packet with disallowed source address from",
                                                peer.String(),
                                        )
                                        continue
@@ -584,7 +590,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
                                if device.routingTable.LookupIPv6(src) != peer {
                                        logInfo.Println(
-                                               "IPv6 packet with unallowed source address from",
+                                               "IPv6 packet with disallowed source address from",
                                                peer.String(),
                                        )
                                        continue
index 35a4a6e3b52491401d5689da8ae6839b85f73808..2919f2e1caf4475fdd27a9473f35454774a50248 100644 (file)
@@ -11,7 +11,7 @@ import (
        "time"
 )
 
-/* Handles outbound flow
+/* Outbound flow
  *
  * 1. TUN queue
  * 2. Routing (sequential)
@@ -19,17 +19,22 @@ import (
  * 4. Encryption (parallel)
  * 5. Transmission (sequential)
  *
- * The order of packets (per peer) is maintained.
- * The functions in this file occure (roughly) in the order packets are processed.
- */
-
-/* The sequential consumers will attempt to take the lock,
+ * The functions in this file occur (roughly) in the order in
+ * which the packets are processed.
+ *
+ * Locking, Producers and Consumers
+ *
+ * The order of packets (per peer) must be maintained,
+ * but encryption of packets happen out-of-order:
+ *
+ * The sequential consumers will attempt to take the lock,
  * workers release lock when they have completed work (encryption) on the packet.
  *
  * If the element is inserted into the "encryption queue",
- * the content is preceeded by enough "junk" to contain the transport header
+ * the content is preceded by enough "junk" to contain the transport header
  * (to allow the construction of transport messages in-place)
  */
+
 type QueueOutboundElement struct {
        dropped int32
        mutex   sync.Mutex
@@ -155,7 +160,7 @@ func (device *Device) RoutineReadFromTUN() {
                        peer = device.routingTable.LookupIPv6(dst)
 
                default:
-                       logDebug.Println("Receieved packet with unknown IP version")
+                       logDebug.Println("Received packet with unknown IP version")
                }
 
                if peer == nil {
@@ -249,7 +254,7 @@ func (device *Device) RoutineEncryption() {
                // fetch next element
 
                select {
-               case <-device.signal.stop:
+               case <-device.signal.stop.Wait():
                        logDebug.Println("Routine, encryption worker, stopped")
                        return
 
index 96b21bbc20f4bad207aa96fad95a36dc319b79b6..2cefad460fd34e3ddfb5cfecfc6c6a04f5ea5885 100644 (file)
@@ -20,6 +20,8 @@ func (s *Signal) Enable() {
        s.enabled.Set(true)
 }
 
+/* Unblock exactly one listener
+ */
 func (s *Signal) Send() {
        if s.enabled.Get() {
                select {
@@ -29,6 +31,8 @@ func (s *Signal) Send() {
        }
 }
 
+/* Clear the signal if already fired
+ */
 func (s Signal) Clear() {
        select {
        case <-s.C:
@@ -36,10 +40,14 @@ func (s Signal) Clear() {
        }
 }
 
+/* Unblocks all listeners (forever)
+ */
 func (s Signal) Broadcast() {
-       close(s.C) // unblocks all selectors
+       close(s.C)
 }
 
+/* Wait for the signal
+ */
 func (s Signal) Wait() chan struct{} {
        return s.C
 }
index 64aeca86afb5a9209734ab740996588f6d7af484..ee47393cb28a6eb4a8b9dc1c5a375b47ea099573 100644 (file)
@@ -27,7 +27,7 @@ func (peer *Peer) KeepKeyFreshSending() {
 \r
 /* Called when a new authenticated message has been received\r
  *\r
- * NOTE: Not thread safe (called by sequential receiver)\r
+ * NOTE: Not thread safe, but called by sequential receiver!\r
  */\r
 func (peer *Peer) KeepKeyFreshReceiving() {\r
        if peer.timer.sendLastMinuteHandshake {\r
index 38fcd4a6e139a6faacd9650e7846d2220b8e2a85..405ffc325646ad924f0c00713cf4cd7e8c7b3e94 100644 (file)
@@ -11,10 +11,8 @@ import (
  * same way as those created by the "net" functions.
  * Here the IPs are slices of either 4 or 16 byte (not always 16)
  *
- * Syncronization done seperatly
+ * Synchronization done separately
  * See: routing.go
- *
- * TODO: Better commenting
  */
 
 type Trie struct {
@@ -30,7 +28,11 @@ type Trie struct {
 }
 
 /* Finds length of matching prefix
- * TODO: Make faster
+ *
+ * TODO: Only use during insertion (xor + prefix mask for lookup)
+ *       Check out
+ *       prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
+ *       https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
  *
  * Assumption:
  *       len(ip1) == len(ip2)
@@ -88,7 +90,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
                return node
        }
 
-       // walk recursivly
+       // walk recursively
 
        node.child[0] = node.child[0].RemovePeer(p)
        node.child[1] = node.child[1].RemovePeer(p)