]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Finer-grained start-stop synchronization
authorJason A. Donenfeld <Jason@zx2c4.com>
Wed, 16 May 2018 20:20:15 +0000 (22:20 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Wed, 16 May 2018 20:20:15 +0000 (22:20 +0200)
conn.go
device.go
peer.go
receive.go
send.go
tun.go

diff --git a/conn.go b/conn.go
index 92f4cfe8808734a72f67a1db049afd9a5f59e474..d3919ca29b8073cffa35675c103400da134edf57 100644 (file)
--- a/conn.go
+++ b/conn.go
@@ -12,6 +12,10 @@ import (
        "net"
 )
 
+const (
+       ConnRoutineNumber = 2
+)
+
 /* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
  */
 type Bind interface {
@@ -153,6 +157,8 @@ func (device *Device) BindUpdate() error {
 
                // start receiving routines
 
+               device.state.starting.Add(ConnRoutineNumber)
+               device.state.stopping.Add(ConnRoutineNumber)
                go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
                go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
 
index e91ca72d9217bfd0bad84872faf9bad7daad563b..6e1bc94e8377c21c23a98b8970a6b953626edafb 100644 (file)
--- a/device.go
+++ b/device.go
@@ -15,6 +15,7 @@ import (
 
 const (
        DeviceRoutineNumberPerCPU = 3
+       DeviceRoutineNumberAdditional = 2
 )
 
 type Device struct {
@@ -25,6 +26,7 @@ type Device struct {
        // synchronized resources (locks acquired in order)
 
        state struct {
+               starting sync.WaitGroup
                stopping sync.WaitGroup
                mutex    sync.Mutex
                changing AtomicBool
@@ -297,7 +299,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
        // start workers
 
        cpus := runtime.NumCPU()
-       device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus)
+       device.state.starting.Wait()
+       device.state.stopping.Wait()
+       device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional)
+       device.state.starting.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional)
        for i := 0; i < cpus; i += 1 {
                go device.RoutineEncryption()
                go device.RoutineDecryption()
@@ -307,6 +312,8 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
        go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
 
+       device.state.starting.Wait()
+
        return device
 }
 
@@ -363,6 +370,9 @@ func (device *Device) Close() {
        if device.isClosed.Swap(true) {
                return
        }
+
+       device.state.starting.Wait()
+
        device.log.Info.Println("Device closing")
        device.state.changing.Set(true)
        device.state.mutex.Lock()
diff --git a/peer.go b/peer.go
index 4bc1adafc421175b836d2a0eb31aac94e8dbb13b..3808ad64f835f634a4d8eb23e42cb8a1a9f78a38 100644 (file)
--- a/peer.go
+++ b/peer.go
@@ -231,20 +231,21 @@ func (peer *Peer) Stop() {
 
        // prevent simultaneous start/stop operations
 
-       peer.routines.mutex.Lock()
-       defer peer.routines.mutex.Unlock()
-
        if !peer.isRunning.Swap(false) {
                return
        }
 
+       peer.routines.starting.Wait()
+
+       peer.routines.mutex.Lock()
+       defer peer.routines.mutex.Unlock()
+
        peer.device.log.Debug.Println(peer, ": Stopping...")
 
        peer.timersStop()
 
        // stop & wait for ongoing peer routines
 
-       peer.routines.starting.Wait()
        close(peer.routines.stop)
        peer.routines.stopping.Wait()
 
index 77062fa6a2c8f1dde1ee6b546ae005901102ba4d..aa96057bdf7534ad4caeb53ed3dff6862aac40ee 100644 (file)
@@ -124,9 +124,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
        logDebug := device.log.Debug
        defer func() {
                logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+               device.state.stopping.Done()
        }()
 
        logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting")
+       device.state.starting.Done()
 
        // receive datagrams until conn is closed
 
@@ -257,6 +259,7 @@ func (device *Device) RoutineDecryption() {
                device.state.stopping.Done()
        }()
        logDebug.Println("Routine: decryption worker - started")
+       device.state.starting.Done()
 
        for {
                select {
@@ -324,6 +327,7 @@ func (device *Device) RoutineHandshake() {
        }()
 
        logDebug.Println("Routine: handshake worker - started")
+       device.state.starting.Done()
 
        var elem QueueHandshakeElement
        var ok bool
diff --git a/send.go b/send.go
index 9a59abd8fbfe69dde54682c9acc5c3eb1acbae5a..5605ad11eab9efe943f924df93c76f33f6a07490 100644 (file)
--- a/send.go
+++ b/send.go
@@ -247,9 +247,11 @@ func (device *Device) RoutineReadFromTUN() {
 
        defer func() {
                logDebug.Println("Routine: TUN reader - stopped")
+               device.state.stopping.Done()
        }()
 
        logDebug.Println("Routine: TUN reader - started")
+       device.state.starting.Done()
 
        for {
 
@@ -424,6 +426,7 @@ func (device *Device) RoutineEncryption() {
        }()
 
        logDebug.Println("Routine: encryption worker - started")
+       device.state.starting.Done()
 
        for {
 
diff --git a/tun.go b/tun.go
index ec3ab479cc14cda533613a6bcc123b6db295fa86..ef80625d4e619fe24844f51d3b40819a99b94fd7 100644 (file)
--- a/tun.go
+++ b/tun.go
@@ -35,6 +35,8 @@ func (device *Device) RoutineTUNEventReader() {
        logInfo := device.log.Info
        logError := device.log.Error
 
+       device.state.starting.Done()
+
        for event := range device.tun.device.Events() {
                if event&TUNEventMTUUpdate != 0 {
                        mtu, err := device.tun.device.MTU()
@@ -63,4 +65,6 @@ func (device *Device) RoutineTUNEventReader() {
                        device.Down()
                }
        }
+
+       device.state.stopping.Done()
 }