]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: fix persistent_keepalive_interval data races
authorJosh Bleecher Snyder <josh@tailscale.com>
Mon, 14 Dec 2020 23:28:52 +0000 (15:28 -0800)
committerJosh Bleecher Snyder <josh@tailscale.com>
Wed, 16 Dec 2020 00:57:09 +0000 (16:57 -0800)
Co-authored-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
device/device.go
device/device_test.go
device/peer.go
device/timers.go
device/uapi.go

index d9367e5e24ed4f25d5e3d4e22fd0fe7c5143ae18..99f5e602d83ed744cdabd3c38066b7e2a5ff3338 100644 (file)
@@ -163,7 +163,7 @@ func deviceUpdateState(device *Device) {
                device.peers.RLock()
                for _, peer := range device.peers.keyMap {
                        peer.Start()
-                       if peer.persistentKeepaliveInterval > 0 {
+                       if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
                                peer.SendKeepalive()
                        }
                }
index 65942ecf2a301fcf4555538fda8794c08238efbd..e14391412ea72ad39b3bfc2d83d50229cb7ca069 100644 (file)
@@ -215,7 +215,20 @@ func TestConcurrencySafety(t *testing.T) {
        }()
        warmup.Wait()
 
-       // coming soon: more things here...
+       // Change persistent_keepalive_interval concurrently with tunnel use.
+       t.Run("persistentKeepaliveInterval", func(t *testing.T) {
+               cfg := uapiCfg(
+                       "public_key", "f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725",
+                       "persistent_keepalive_interval", "1",
+               )
+               for i := 0; i < 1000; i++ {
+                       cfg.Seek(0, io.SeekStart)
+                       err := pair[0].dev.IpcSetOperation(cfg)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+               }
+       })
 
        close(done)
 }
index c2397cc9dc6142dd579778c14f81f3eba3758b46..31b75c7b3375871866bd063c7bf8958f17914f92 100644 (file)
@@ -27,7 +27,7 @@ type Peer struct {
        handshake                   Handshake
        device                      *Device
        endpoint                    conn.Endpoint
-       persistentKeepaliveInterval uint16
+       persistentKeepaliveInterval uint32 // accessed atomically
        disableRoaming              bool
 
        // These fields are accessed with atomic operations, which must be
index 48cef94b2269932f7b7b6487b2ec494da8674992..e94da3654757e05eabf75344298a32ded003df58 100644 (file)
@@ -138,7 +138,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
 }
 
 func expiredPersistentKeepalive(peer *Peer) {
-       if peer.persistentKeepaliveInterval > 0 {
+       if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
                peer.SendKeepalive()
        }
 }
@@ -201,8 +201,9 @@ func (peer *Peer) timersSessionDerived() {
 
 /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
 func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
-       if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
-               peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+       keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+       if keepalive > 0 && peer.timersActive() {
+               peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
        }
 }
 
index c0e522b16abdedec2bf58ce5c8fc38e1acc064d1..3f26607bb503ff71a96d6ce2152dfa04636d96c3 100644 (file)
@@ -86,7 +86,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
                        send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
                        send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
                        send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
-                       send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+                       send(fmt.Sprintf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)))
 
                        for _, ip := range device.allowedips.EntriesForPeer(peer) {
                                send("allowed_ip=" + ip.String())
@@ -333,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                        return &IPCError{ipc.IpcErrorInvalid}
                                }
 
-                               old := peer.persistentKeepaliveInterval
-                               peer.persistentKeepaliveInterval = uint16(secs)
+                               old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
 
                                // send immediate keepalive if we're turning it on and before it wasn't on