]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: change Peer.endpoint locking to reduce contention
authorJordan Whited <jordan@tailscale.com>
Tue, 21 Nov 2023 00:49:06 +0000 (16:49 -0800)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 11 Dec 2023 15:34:09 +0000 (16:34 +0100)
Access to Peer.endpoint was previously synchronized by Peer.RWMutex.
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the
sole caller of Endpoint.ClearSrc(), which is signaled via a new bool,
Peer.endpoint.clearSrcOnTx. Previous Callers of Endpoint.ClearSrc() now
set this bool, primarily via peer.markEndpointSrcForClearing().
Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an
updated conn.Endpoint is stored. This maintains the same event order as
before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx
is set, but before the next Peer.SendBuffers() call results in the
latest conn.Endpoint source being used for the next packet transmission.

These changes result in throughput improvements for single flow,
parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP
tests as measured on both Linux and Windows. Latency under load improves
especially for high throughput Linux scenarios. These improvements are
likely realized on all platforms to some degree, as the changes are not
platform-specific.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/device.go
device/mobilequirks.go
device/peer.go
device/sticky_linux.go
device/timers.go
device/uapi.go

index f9557a075b50434feaec3a982bc1738527bef482..ca26d00c1e5bd7e080e8761e23715d6bb76c25c5 100644 (file)
@@ -461,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
        // clear cached source addresses
        device.peers.RLock()
        for _, peer := range device.peers.keyMap {
-               peer.Lock()
-               defer peer.Unlock()
-               if peer.endpoint != nil {
-                       peer.endpoint.ClearSrc()
-               }
+               peer.markEndpointSrcForClearing()
        }
        device.peers.RUnlock()
 
@@ -515,11 +511,7 @@ func (device *Device) BindUpdate() error {
        // clear cached source addresses
        device.peers.RLock()
        for _, peer := range device.peers.keyMap {
-               peer.Lock()
-               defer peer.Unlock()
-               if peer.endpoint != nil {
-                       peer.endpoint.ClearSrc()
-               }
+               peer.markEndpointSrcForClearing()
        }
        device.peers.RUnlock()
 
index 4e5051d7e69e3f811f9860352015d2dc39764d24..0a0080efd8d014eb5bc28201f623aec23c33ea74 100644 (file)
@@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
        device.net.brokenRoaming = true
        device.peers.RLock()
        for _, peer := range device.peers.keyMap {
-               peer.Lock()
-               peer.disableRoaming = peer.endpoint != nil
-               peer.Unlock()
+               peer.endpoint.Lock()
+               peer.endpoint.disableRoaming = peer.endpoint.val != nil
+               peer.endpoint.Unlock()
        }
        device.peers.RUnlock()
 }
index 2fb5da62a53ea0ced889ace7b3cb06840f6fc7d1..47a2f14418a143dd48b0b5c7c38c5dbfd72afbcb 100644 (file)
@@ -17,17 +17,20 @@ import (
 
 type Peer struct {
        isRunning         atomic.Bool
-       sync.RWMutex      // Mostly protects endpoint, but is generally taken whenever we modify peer
        keypairs          Keypairs
        handshake         Handshake
        device            *Device
-       endpoint          conn.Endpoint
        stopping          sync.WaitGroup // routines pending stop
        txBytes           atomic.Uint64  // bytes send to peer (endpoint)
        rxBytes           atomic.Uint64  // bytes received from peer
        lastHandshakeNano atomic.Int64   // nano seconds since epoch
 
-       disableRoaming bool
+       endpoint struct {
+               sync.Mutex
+               val            conn.Endpoint
+               clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
+               disableRoaming bool
+       }
 
        timers struct {
                retransmitHandshake     *Timer
@@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        // create peer
        peer := new(Peer)
-       peer.Lock()
-       defer peer.Unlock()
 
        peer.cookieGenerator.Init(pk)
        peer.device = device
@@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        handshake.mutex.Unlock()
 
        // reset endpoint
-       peer.endpoint = nil
+       peer.endpoint.Lock()
+       peer.endpoint.val = nil
+       peer.endpoint.disableRoaming = false
+       peer.endpoint.clearSrcOnTx = false
+       peer.endpoint.Unlock()
 
        // init timers
        peer.timersInit()
@@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
                return nil
        }
 
-       peer.RLock()
-       defer peer.RUnlock()
-
-       if peer.endpoint == nil {
+       peer.endpoint.Lock()
+       endpoint := peer.endpoint.val
+       if endpoint == nil {
+               peer.endpoint.Unlock()
                return errors.New("no known endpoint for peer")
        }
+       if peer.endpoint.clearSrcOnTx {
+               endpoint.ClearSrc()
+               peer.endpoint.clearSrcOnTx = false
+       }
+       peer.endpoint.Unlock()
 
-       err := peer.device.net.bind.Send(buffers, peer.endpoint)
+       err := peer.device.net.bind.Send(buffers, endpoint)
        if err == nil {
                var totalLen uint64
                for _, b := range buffers {
@@ -267,10 +277,20 @@ func (peer *Peer) Stop() {
 }
 
 func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
-       if peer.disableRoaming {
+       peer.endpoint.Lock()
+       defer peer.endpoint.Unlock()
+       if peer.endpoint.disableRoaming {
+               return
+       }
+       peer.endpoint.clearSrcOnTx = false
+       peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+       peer.endpoint.Lock()
+       defer peer.endpoint.Unlock()
+       if peer.endpoint.val == nil {
                return
        }
-       peer.Lock()
-       peer.endpoint = endpoint
-       peer.Unlock()
+       peer.endpoint.clearSrcOnTx = true
 }
index f9230f8c3c363c2d556b37cc65aa58f83bc08846..6057ff12ab53f05c850c6f7c8534d55822fe06a5 100644 (file)
@@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
                                                                if !ok {
                                                                        break
                                                                }
-                                                               pePtr.peer.Lock()
-                                                               if &pePtr.peer.endpoint != pePtr.endpoint {
-                                                                       pePtr.peer.Unlock()
+                                                               pePtr.peer.endpoint.Lock()
+                                                               if &pePtr.peer.endpoint.val != pePtr.endpoint {
+                                                                       pePtr.peer.endpoint.Unlock()
                                                                        break
                                                                }
-                                                               if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
-                                                                       pePtr.peer.Unlock()
+                                                               if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+                                                                       pePtr.peer.endpoint.Unlock()
                                                                        break
                                                                }
-                                                               pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
-                                                               pePtr.peer.Unlock()
+                                                               pePtr.peer.endpoint.clearSrcOnTx = true
+                                                               pePtr.peer.endpoint.Unlock()
                                                        }
                                                        attr = attr[attrhdr.Len:]
                                                }
@@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
                                        device.peers.RLock()
                                        i := uint32(1)
                                        for _, peer := range device.peers.keyMap {
-                                               peer.RLock()
-                                               if peer.endpoint == nil {
-                                                       peer.RUnlock()
+                                               peer.endpoint.Lock()
+                                               if peer.endpoint.val == nil {
+                                                       peer.endpoint.Unlock()
                                                        continue
                                                }
-                                               nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
+                                               nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
                                                if nativeEP == nil {
-                                                       peer.RUnlock()
+                                                       peer.endpoint.Unlock()
                                                        continue
                                                }
                                                if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
-                                                       peer.RUnlock()
+                                                       peer.endpoint.Unlock()
                                                        break
                                                }
                                                nlmsg := struct {
@@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
                                                reqPeerLock.Lock()
                                                reqPeer[i] = peerEndpointPtr{
                                                        peer:     peer,
-                                                       endpoint: &peer.endpoint,
+                                                       endpoint: &peer.endpoint.val,
                                                }
                                                reqPeerLock.Unlock()
-                                               peer.RUnlock()
+                                               peer.endpoint.Unlock()
                                                i++
                                                _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
                                                if err != nil {
index e28732c42556aa6e9a67a8d2da240b7a138f4a51..d4a4ed4e5ff59f0c019bcc48b8d5c0ce0de03bef 100644 (file)
@@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
                peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
 
                /* We clear the endpoint address src address, in case this is the cause of trouble. */
-               peer.Lock()
-               if peer.endpoint != nil {
-                       peer.endpoint.ClearSrc()
-               }
-               peer.Unlock()
+               peer.markEndpointSrcForClearing()
 
                peer.SendHandshakeInitiation(true)
        }
@@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
 func expiredNewHandshake(peer *Peer) {
        peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
        /* We clear the endpoint address src address, in case this is the cause of trouble. */
-       peer.Lock()
-       if peer.endpoint != nil {
-               peer.endpoint.ClearSrc()
-       }
-       peer.Unlock()
+       peer.markEndpointSrcForClearing()
        peer.SendHandshakeInitiation(false)
 }
 
index 617dcd333885c76d6f8082ea888f870e8a5bf24f..d81dae3b65d0aae5872798927c4571310df6714a 100644 (file)
@@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
 
                for _, peer := range device.peers.keyMap {
                        // Serialize peer state.
-                       // Do the work in an anonymous function so that we can use defer.
-                       func() {
-                               peer.RLock()
-                               defer peer.RUnlock()
-
-                               keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
-                               keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
-                               sendf("protocol_version=1")
-                               if peer.endpoint != nil {
-                                       sendf("endpoint=%s", peer.endpoint.DstToString())
-                               }
-
-                               nano := peer.lastHandshakeNano.Load()
-                               secs := nano / time.Second.Nanoseconds()
-                               nano %= time.Second.Nanoseconds()
-
-                               sendf("last_handshake_time_sec=%d", secs)
-                               sendf("last_handshake_time_nsec=%d", nano)
-                               sendf("tx_bytes=%d", peer.txBytes.Load())
-                               sendf("rx_bytes=%d", peer.rxBytes.Load())
-                               sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
-
-                               device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
-                                       sendf("allowed_ip=%s", prefix.String())
-                                       return true
-                               })
-                       }()
+                       peer.handshake.mutex.RLock()
+                       keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
+                       keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+                       peer.handshake.mutex.RUnlock()
+                       sendf("protocol_version=1")
+                       peer.endpoint.Lock()
+                       if peer.endpoint.val != nil {
+                               sendf("endpoint=%s", peer.endpoint.val.DstToString())
+                       }
+                       peer.endpoint.Unlock()
+
+                       nano := peer.lastHandshakeNano.Load()
+                       secs := nano / time.Second.Nanoseconds()
+                       nano %= time.Second.Nanoseconds()
+
+                       sendf("last_handshake_time_sec=%d", secs)
+                       sendf("last_handshake_time_nsec=%d", nano)
+                       sendf("tx_bytes=%d", peer.txBytes.Load())
+                       sendf("rx_bytes=%d", peer.rxBytes.Load())
+                       sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
+
+                       device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+                               sendf("allowed_ip=%s", prefix.String())
+                               return true
+                       })
                }
        }()
 
@@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
                return
        }
        if peer.created {
-               peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
+               peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
        }
        if peer.device.isUp() {
                peer.Start()
@@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
                if err != nil {
                        return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
                }
-               peer.Lock()
-               defer peer.Unlock()
-               peer.endpoint = endpoint
+               peer.endpoint.Lock()
+               defer peer.endpoint.Unlock()
+               peer.endpoint.val = endpoint
 
        case "persistent_keepalive_interval":
                device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)