]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
all: use Go 1.19 and its atomic types
authorBrad Fitzpatrick <bradfitz@tailscale.com>
Tue, 30 Aug 2022 14:43:11 +0000 (07:43 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 4 Sep 2022 10:57:30 +0000 (12:57 +0200)
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
20 files changed:
conn/bind_windows.go
device/alignment_test.go [deleted file]
device/device.go
device/device_test.go
device/keypair.go
device/misc.go [deleted file]
device/noise-protocol.go
device/noise_test.go
device/peer.go
device/pools.go
device/pools_test.go
device/receive.go
device/send.go
device/timers.go
device/tun.go
device/uapi.go
go.mod
ipc/namedpipe/file.go
ipc/namedpipe/namedpipe.go
tun/tun_windows.go

index 9268bc15fe7d257ba504cdaaad9f3d4c81c1b412..c066efa4e99cb314835f9305fb619434d5ff7101 100644 (file)
@@ -74,7 +74,7 @@ type afWinRingBind struct {
 type WinRingBind struct {
        v4, v6 afWinRingBind
        mu     sync.RWMutex
-       isOpen uint32
+       isOpen atomic.Uint32 // 0, 1, or 2
 }
 
 func NewDefaultBind() Bind { return NewWinRingBind() }
@@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
 }
 
 func (bind *WinRingBind) closeAndZero() {
-       atomic.StoreUint32(&bind.isOpen, 0)
+       bind.isOpen.Store(0)
        bind.v4.CloseAndZero()
        bind.v6.CloseAndZero()
 }
@@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
                        bind.closeAndZero()
                }
        }()
-       if atomic.LoadUint32(&bind.isOpen) != 0 {
+       if bind.isOpen.Load() != 0 {
                return nil, 0, ErrBindAlreadyOpen
        }
        var sa windows.Sockaddr
@@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
                        return nil, 0, err
                }
        }
-       atomic.StoreUint32(&bind.isOpen, 1)
+       bind.isOpen.Store(1)
        return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
 }
 
 func (bind *WinRingBind) Close() error {
        bind.mu.RLock()
-       if atomic.LoadUint32(&bind.isOpen) != 1 {
+       if bind.isOpen.Load() != 1 {
                bind.mu.RUnlock()
                return nil
        }
-       atomic.StoreUint32(&bind.isOpen, 2)
+       bind.isOpen.Store(2)
        windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
        windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
        windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
 //go:linkname procyield runtime.procyield
 func procyield(cycles uint32)
 
-func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
-       if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
+       if isOpen.Load() != 1 {
                return 0, nil, net.ErrClosed
        }
        bind.rx.mu.Lock()
@@ -359,7 +359,7 @@ retry:
        count = 0
        for tries := 0; count == 0 && tries < receiveSpins; tries++ {
                if tries > 0 {
-                       if atomic.LoadUint32(isOpen) != 1 {
+                       if isOpen.Load() != 1 {
                                return 0, nil, net.ErrClosed
                        }
                        procyield(1)
@@ -378,7 +378,7 @@ retry:
                if err != nil {
                        return 0, nil, err
                }
-               if atomic.LoadUint32(isOpen) != 1 {
+               if isOpen.Load() != 1 {
                        return 0, nil, net.ErrClosed
                }
                count = winrio.DequeueCompletion(bind.rx.cq, results[:])
@@ -395,7 +395,7 @@ retry:
        // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
        // attacker bandwidth, just like the rest of the receive path.
        if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
-               if atomic.LoadUint32(isOpen) != 1 {
+               if isOpen.Load() != 1 {
                        return 0, nil, net.ErrClosed
                }
                goto retry
@@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
        return bind.v6.Receive(buf, &bind.isOpen)
 }
 
-func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
-       if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
+       if isOpen.Load() != 1 {
                return net.ErrClosed
        }
        if len(buf) > bytesPerPacket {
@@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
                if err != nil {
                        return err
                }
-               if atomic.LoadUint32(isOpen) != 1 {
+               if isOpen.Load() != 1 {
                        return net.ErrClosed
                }
                count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
        bind.mu.RLock()
        defer bind.mu.RUnlock()
-       if atomic.LoadUint32(&bind.isOpen) != 1 {
+       if bind.isOpen.Load() != 1 {
                return net.ErrClosed
        }
        err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
        bind.mu.RLock()
        defer bind.mu.RUnlock()
-       if atomic.LoadUint32(&bind.isOpen) != 1 {
+       if bind.isOpen.Load() != 1 {
                return net.ErrClosed
        }
        err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
diff --git a/device/alignment_test.go b/device/alignment_test.go
deleted file mode 100644 (file)
index a918112..0000000
+++ /dev/null
@@ -1,65 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
-       "reflect"
-       "testing"
-       "unsafe"
-)
-
-func checkAlignment(t *testing.T, name string, offset uintptr) {
-       t.Helper()
-       if offset%8 != 0 {
-               t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
-       }
-}
-
-// TestPeerAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestPeerAlignment(t *testing.T) {
-       var p Peer
-
-       typ := reflect.TypeOf(&p).Elem()
-       t.Logf("Peer type size: %d, with fields:", typ.Size())
-       for i := 0; i < typ.NumField(); i++ {
-               field := typ.Field(i)
-               t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
-                       field.Name,
-                       field.Offset,
-                       field.Type.Size(),
-                       field.Type.Align(),
-               )
-       }
-
-       checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
-       checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
-}
-
-// TestDeviceAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestDeviceAlignment(t *testing.T) {
-       var d Device
-
-       typ := reflect.TypeOf(&d).Elem()
-       t.Logf("Device type size: %d, with fields:", typ.Size())
-       for i := 0; i < typ.NumField(); i++ {
-               field := typ.Field(i)
-               t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
-                       field.Name,
-                       field.Offset,
-                       field.Type.Size(),
-                       field.Type.Align(),
-               )
-       }
-       checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
-}
index 3625608db58e35ec083618c8281a37bf5b5182bb..f96e27707ad8c2a2b4fb1622f4073580310f4fb0 100644 (file)
@@ -30,7 +30,7 @@ type Device struct {
                // will become the actual state; Up can fail.
                // The device can also change state multiple times between time of check and time of use.
                // Unsynchronized uses of state must therefore be advisory/best-effort only.
-               state uint32 // actually a deviceState, but typed uint32 for convenience
+               state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
                // stopping blocks until all inputs to Device have been closed.
                stopping sync.WaitGroup
                // mu protects state changes.
@@ -58,9 +58,8 @@ type Device struct {
                keyMap       map[NoisePublicKey]*Peer
        }
 
-       // Keep this 8-byte aligned
        rate struct {
-               underLoadUntil int64
+               underLoadUntil atomic.Int64
                limiter        ratelimiter.Ratelimiter
        }
 
@@ -82,7 +81,7 @@ type Device struct {
 
        tun struct {
                device tun.Device
-               mtu    int32
+               mtu    atomic.Int32
        }
 
        ipcMutex sync.RWMutex
@@ -94,10 +93,9 @@ type Device struct {
 // There are three states: down, up, closed.
 // Transitions:
 //
-//   down -----+
-//     ↑↓      ↓
-//     up -> closed
-//
+//     down -----+
+//       ↑↓      ↓
+//       up -> closed
 type deviceState uint32
 
 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -110,7 +108,7 @@ const (
 // deviceState returns device.state.state as a deviceState
 // See those docs for how to interpret this value.
 func (device *Device) deviceState() deviceState {
-       return deviceState(atomic.LoadUint32(&device.state.state))
+       return deviceState(device.state.state.Load())
 }
 
 // isClosed reports whether the device is closed (or is closing).
@@ -149,14 +147,14 @@ func (device *Device) changeState(want deviceState) (err error) {
        case old:
                return nil
        case deviceStateUp:
-               atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
+               device.state.state.Store(uint32(deviceStateUp))
                err = device.upLocked()
                if err == nil {
                        break
                }
                fallthrough // up failed; bring the device all the way back down
        case deviceStateDown:
-               atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
+               device.state.state.Store(uint32(deviceStateDown))
                errDown := device.downLocked()
                if err == nil {
                        err = errDown
@@ -182,7 +180,7 @@ func (device *Device) upLocked() error {
        device.peers.RLock()
        for _, peer := range device.peers.keyMap {
                peer.Start()
-               if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+               if peer.persistentKeepaliveInterval.Load() > 0 {
                        peer.SendKeepalive()
                }
        }
@@ -219,11 +217,11 @@ func (device *Device) IsUnderLoad() bool {
        now := time.Now()
        underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
        if underLoad {
-               atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
+               device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
                return true
        }
        // check if recently under load
-       return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
+       return device.rate.underLoadUntil.Load() > now.UnixNano()
 }
 
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -283,7 +281,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 
 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
        device := new(Device)
-       device.state.state = uint32(deviceStateDown)
+       device.state.state.Store(uint32(deviceStateDown))
        device.closed = make(chan struct{})
        device.log = logger
        device.net.bind = bind
@@ -293,7 +291,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
                device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
                mtu = DefaultMTU
        }
-       device.tun.mtu = int32(mtu)
+       device.tun.mtu.Store(int32(mtu))
        device.peers.keyMap = make(map[NoisePublicKey]*Peer)
        device.rate.limiter.Init()
        device.indexTable.Init()
@@ -359,7 +357,7 @@ func (device *Device) Close() {
        if device.isClosed() {
                return
        }
-       atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
+       device.state.state.Store(uint32(deviceStateClosed))
        device.log.Verbosef("Device closing")
 
        device.tun.device.Close()
index ab7236efae3626351bc6e778e1b713d0a7f9c600..8cffe08d830bfd3cfb2c8a37446f8cf9e81edb44 100644 (file)
@@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
 
        // Measure how long it takes to receive b.N packets,
        // starting when we receive the first packet.
-       var recv uint64
+       var recv atomic.Uint64
        var elapsed time.Duration
        var wg sync.WaitGroup
        wg.Add(1)
@@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
                var start time.Time
                for {
                        <-pair[0].tun.Inbound
-                       new := atomic.AddUint64(&recv, 1)
+                       new := recv.Add(1)
                        if new == 1 {
                                start = time.Now()
                        }
@@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
        ping := tuntest.Ping(pair[0].ip, pair[1].ip)
        pingc := pair[1].tun.Outbound
        var sent uint64
-       for atomic.LoadUint64(&recv) != uint64(b.N) {
+       for recv.Load() != uint64(b.N) {
                sent++
                pingc <- ping
        }
index 788c947f457c434185dfdc218b64a59839ff7188..206d7a906aa47149bdb52bd961b8549870807052 100644 (file)
@@ -10,7 +10,6 @@ import (
        "sync"
        "sync/atomic"
        "time"
-       "unsafe"
 
        "golang.zx2c4.com/wireguard/replay"
 )
@@ -23,7 +22,7 @@ import (
  */
 
 type Keypair struct {
-       sendNonce    uint64 // accessed atomically
+       sendNonce    atomic.Uint64
        send         cipher.AEAD
        receive      cipher.AEAD
        replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
        sync.RWMutex
        current  *Keypair
        previous *Keypair
-       next     *Keypair
-}
-
-func (kp *Keypairs) storeNext(next *Keypair) {
-       atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
-}
-
-func (kp *Keypairs) loadNext() *Keypair {
-       return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+       next     atomic.Pointer[Keypair]
 }
 
 func (kp *Keypairs) Current() *Keypair {
diff --git a/device/misc.go b/device/misc.go
deleted file mode 100644 (file)
index 4126704..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
-       "sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
-       AtomicFalse = int32(iota)
-       AtomicTrue
-)
-
-type AtomicBool struct {
-       int32
-}
-
-func (a *AtomicBool) Get() bool {
-       return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
-       flag := AtomicFalse
-       if val {
-               flag = AtomicTrue
-       }
-       return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
-       flag := AtomicFalse
-       if val {
-               flag = AtomicTrue
-       }
-       atomic.StoreInt32(&a.int32, flag)
-}
index ffa04528b3aaeeab195874eb55adeb5639adc6fa..410926ea4f089573c11eecd735143eb9693f3eb6 100644 (file)
@@ -282,7 +282,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        // lookup peer
 
        peer := device.LookupPeer(peerPK)
-       if peer == nil || !peer.isRunning.Get() {
+       if peer == nil || !peer.isRunning.Load() {
                return nil
        }
 
@@ -581,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error {
        defer keypairs.Unlock()
 
        previous := keypairs.previous
-       next := keypairs.loadNext()
+       next := keypairs.next.Load()
        current := keypairs.current
 
        if isInitiator {
                if next != nil {
-                       keypairs.storeNext(nil)
+                       keypairs.next.Store(nil)
                        keypairs.previous = next
                        device.DeleteKeypair(current)
                } else {
@@ -595,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error {
                device.DeleteKeypair(previous)
                keypairs.current = keypair
        } else {
-               keypairs.storeNext(keypair)
+               keypairs.next.Store(keypair)
                device.DeleteKeypair(next)
                keypairs.previous = nil
                device.DeleteKeypair(previous)
@@ -607,18 +607,18 @@ func (peer *Peer) BeginSymmetricSession() error {
 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
        keypairs := &peer.keypairs
 
-       if keypairs.loadNext() != receivedKeypair {
+       if keypairs.next.Load() != receivedKeypair {
                return false
        }
        keypairs.Lock()
        defer keypairs.Unlock()
-       if keypairs.loadNext() != receivedKeypair {
+       if keypairs.next.Load() != receivedKeypair {
                return false
        }
        old := keypairs.previous
        keypairs.previous = keypairs.current
        peer.device.DeleteKeypair(old)
-       keypairs.current = keypairs.loadNext()
-       keypairs.storeNext(nil)
+       keypairs.current = keypairs.next.Load()
+       keypairs.next.Store(nil)
        return true
 }
index e2f23c6c24a6061bef16a25d03a2d06ff5b9e898..7c84efc0bb39bb2b97c3778bd1d28e0ddff1038a 100644 (file)
@@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
                t.Fatal("failed to derive keypair for peer 2", err)
        }
 
-       key1 := peer1.keypairs.loadNext()
+       key1 := peer1.keypairs.next.Load()
        key2 := peer2.keypairs.current
 
        // encrypting / decryption test
index 5bd52df794388ff71d0fbdd0646aaaec99995bfe..79feae7959861b689331df7fa504dfa2300ae265 100644 (file)
@@ -16,24 +16,16 @@ import (
 )
 
 type Peer struct {
-       isRunning    AtomicBool
-       sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
-       keypairs     Keypairs
-       handshake    Handshake
-       device       *Device
-       endpoint     conn.Endpoint
-       stopping     sync.WaitGroup // routines pending stop
-
-       // These fields are accessed with atomic operations, which must be
-       // 64-bit aligned even on 32-bit platforms. Go guarantees that an
-       // allocated struct will be 64-bit aligned. So we place
-       // atomically-accessed fields up front, so that they can share in
-       // this alignment before smaller fields throw it off.
-       stats struct {
-               txBytes           uint64 // bytes send to peer (endpoint)
-               rxBytes           uint64 // bytes received from peer
-               lastHandshakeNano int64  // nano seconds since epoch
-       }
+       isRunning         atomic.Bool
+       sync.RWMutex      // Mostly protects endpoint, but is generally taken whenever we modify peer
+       keypairs          Keypairs
+       handshake         Handshake
+       device            *Device
+       endpoint          conn.Endpoint
+       stopping          sync.WaitGroup // routines pending stop
+       txBytes           atomic.Uint64  // bytes send to peer (endpoint)
+       rxBytes           atomic.Uint64  // bytes received from peer
+       lastHandshakeNano atomic.Int64   // nano seconds since epoch
 
        disableRoaming bool
 
@@ -43,9 +35,9 @@ type Peer struct {
                newHandshake            *Timer
                zeroKeyMaterial         *Timer
                persistentKeepalive     *Timer
-               handshakeAttempts       uint32
-               needAnotherKeepalive    AtomicBool
-               sentLastMinuteHandshake AtomicBool
+               handshakeAttempts       atomic.Uint32
+               needAnotherKeepalive    atomic.Bool
+               sentLastMinuteHandshake atomic.Bool
        }
 
        state struct {
@@ -60,7 +52,7 @@ type Peer struct {
 
        cookieGenerator             CookieGenerator
        trieEntries                 list.List
-       persistentKeepaliveInterval uint32 // accessed atomically
+       persistentKeepaliveInterval atomic.Uint32
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -133,7 +125,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
 
        err := peer.device.net.bind.Send(buffer, peer.endpoint)
        if err == nil {
-               atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+               peer.txBytes.Add(uint64(len(buffer)))
        }
        return err
 }
@@ -174,7 +166,7 @@ func (peer *Peer) Start() {
        peer.state.Lock()
        defer peer.state.Unlock()
 
-       if peer.isRunning.Get() {
+       if peer.isRunning.Load() {
                return
        }
 
@@ -198,7 +190,7 @@ func (peer *Peer) Start() {
        go peer.RoutineSequentialSender()
        go peer.RoutineSequentialReceiver()
 
-       peer.isRunning.Set(true)
+       peer.isRunning.Store(true)
 }
 
 func (peer *Peer) ZeroAndFlushAll() {
@@ -210,10 +202,10 @@ func (peer *Peer) ZeroAndFlushAll() {
        keypairs.Lock()
        device.DeleteKeypair(keypairs.previous)
        device.DeleteKeypair(keypairs.current)
-       device.DeleteKeypair(keypairs.loadNext())
+       device.DeleteKeypair(keypairs.next.Load())
        keypairs.previous = nil
        keypairs.current = nil
-       keypairs.storeNext(nil)
+       keypairs.next.Store(nil)
        keypairs.Unlock()
 
        // clear handshake state
@@ -238,11 +230,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
        keypairs := &peer.keypairs
        keypairs.Lock()
        if keypairs.current != nil {
-               atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
+               keypairs.current.sendNonce.Store(RejectAfterMessages)
        }
-       if keypairs.next != nil {
-               next := keypairs.loadNext()
-               atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
+       if next := keypairs.next.Load(); next != nil {
+               next.sendNonce.Store(RejectAfterMessages)
        }
        keypairs.Unlock()
 }
index f40477b6687417609ac0233cbf07253abbf2f339..9da0f799687b3a128e43255a2cd128da1901a26e 100644 (file)
@@ -14,7 +14,7 @@ type WaitPool struct {
        pool  sync.Pool
        cond  sync.Cond
        lock  sync.Mutex
-       count uint32
+       count atomic.Uint32
        max   uint32
 }
 
@@ -27,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
 func (p *WaitPool) Get() any {
        if p.max != 0 {
                p.lock.Lock()
-               for atomic.LoadUint32(&p.count) >= p.max {
+               for p.count.Load() >= p.max {
                        p.cond.Wait()
                }
-               atomic.AddUint32(&p.count, 1)
+               p.count.Add(1)
                p.lock.Unlock()
        }
        return p.pool.Get()
@@ -41,7 +41,7 @@ func (p *WaitPool) Put(x any) {
        if p.max == 0 {
                return
        }
-       atomic.AddUint32(&p.count, ^uint32(0))
+       p.count.Add(^uint32(0))
        p.cond.Signal()
 }
 
index 17e2298f93a2c265d0c9033a47a2426b5973b939..48a98b0f2fb032d734b8eb6399f9a3ed97cc8a14 100644 (file)
@@ -17,29 +17,31 @@ import (
 func TestWaitPool(t *testing.T) {
        t.Skip("Currently disabled")
        var wg sync.WaitGroup
-       trials := int32(100000)
+       var trials atomic.Int32
+       startTrials := int32(100000)
        if raceEnabled {
                // This test can be very slow with -race.
-               trials /= 10
+               startTrials /= 10
        }
+       trials.Store(startTrials)
        workers := runtime.NumCPU() + 2
        if workers-4 <= 0 {
                t.Skip("Not enough cores")
        }
        p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
        wg.Add(workers)
-       max := uint32(0)
+       var max atomic.Uint32
        updateMax := func() {
-               count := atomic.LoadUint32(&p.count)
+               count := p.count.Load()
                if count > p.max {
                        t.Errorf("count (%d) > max (%d)", count, p.max)
                }
                for {
-                       old := atomic.LoadUint32(&max)
+                       old := max.Load()
                        if count <= old {
                                break
                        }
-                       if atomic.CompareAndSwapUint32(&max, old, count) {
+                       if max.CompareAndSwap(old, count) {
                                break
                        }
                }
@@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
        for i := 0; i < workers; i++ {
                go func() {
                        defer wg.Done()
-                       for atomic.AddInt32(&trials, -1) > 0 {
+                       for trials.Add(-1) > 0 {
                                updateMax()
                                x := p.Get()
                                updateMax()
@@ -59,14 +61,15 @@ func TestWaitPool(t *testing.T) {
                }()
        }
        wg.Wait()
-       if max != p.max {
+       if max.Load() != p.max {
                t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
        }
 }
 
 func BenchmarkWaitPool(b *testing.B) {
        var wg sync.WaitGroup
-       trials := int32(b.N)
+       var trials atomic.Int32
+       trials.Store(int32(b.N))
        workers := runtime.NumCPU() + 2
        if workers-4 <= 0 {
                b.Skip("Not enough cores")
@@ -77,7 +80,7 @@ func BenchmarkWaitPool(b *testing.B) {
        for i := 0; i < workers; i++ {
                go func() {
                        defer wg.Done()
-                       for atomic.AddInt32(&trials, -1) > 0 {
+                       for trials.Add(-1) > 0 {
                                x := p.Get()
                                time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
                                p.Put(x)
index cc3449801ca1dc6478db2f7fb35829ff075fc1f5..4dbf1e8270bb47e626af59ace27e2b2d0cdcd7e0 100644 (file)
@@ -11,7 +11,6 @@ import (
        "errors"
        "net"
        "sync"
-       "sync/atomic"
        "time"
 
        "golang.org/x/crypto/chacha20poly1305"
@@ -52,12 +51,12 @@ func (elem *QueueInboundElement) clearPointers() {
  * NOTE: Not thread safe, but called by sequential receiver!
  */
 func (peer *Peer) keepKeyFreshReceiving() {
-       if peer.timers.sentLastMinuteHandshake.Get() {
+       if peer.timers.sentLastMinuteHandshake.Load() {
                return
        }
        keypair := peer.keypairs.Current()
        if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
-               peer.timers.sentLastMinuteHandshake.Set(true)
+               peer.timers.sentLastMinuteHandshake.Store(true)
                peer.SendHandshakeInitiation(false)
        }
 }
@@ -163,7 +162,7 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
                        elem.Lock()
 
                        // add to decryption queues
-                       if peer.isRunning.Get() {
+                       if peer.isRunning.Load() {
                                peer.queue.inbound.c <- elem
                                device.queue.decryption.c <- elem
                                buffer = device.GetMessageBuffer()
@@ -268,7 +267,7 @@ func (device *Device) RoutineHandshake(id int) {
 
                        // consume reply
 
-                       if peer := entry.peer; peer.isRunning.Get() {
+                       if peer := entry.peer; peer.isRunning.Load() {
                                device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
                                if !peer.cookieGenerator.ConsumeReply(&reply) {
                                        device.log.Verbosef("Could not decrypt invalid cookie response")
@@ -341,7 +340,7 @@ func (device *Device) RoutineHandshake(id int) {
                        peer.SetEndpointFromPacket(elem.endpoint)
 
                        device.log.Verbosef("%v - Received handshake initiation", peer)
-                       atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+                       peer.rxBytes.Add(uint64(len(elem.packet)))
 
                        peer.SendHandshakeResponse()
 
@@ -369,7 +368,7 @@ func (device *Device) RoutineHandshake(id int) {
                        peer.SetEndpointFromPacket(elem.endpoint)
 
                        device.log.Verbosef("%v - Received handshake response", peer)
-                       atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+                       peer.rxBytes.Add(uint64(len(elem.packet)))
 
                        // update timers
 
@@ -426,7 +425,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                peer.keepKeyFreshReceiving()
                peer.timersAnyAuthenticatedPacketTraversal()
                peer.timersAnyAuthenticatedPacketReceived()
-               atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
+               peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
 
                if len(elem.packet) == 0 {
                        device.log.Verbosef("%v - Receiving keepalive packet", peer)
index 0a7135f903c041428c5bb90f158e8c042c2d7ebd..471c51c3b744b0cd9ef3a0ff03b5121826400464 100644 (file)
@@ -12,7 +12,6 @@ import (
        "net"
        "os"
        "sync"
-       "sync/atomic"
        "time"
 
        "golang.org/x/crypto/chacha20poly1305"
@@ -76,7 +75,7 @@ func (elem *QueueOutboundElement) clearPointers() {
 /* Queues a keepalive if no packets are queued for peer
  */
 func (peer *Peer) SendKeepalive() {
-       if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
+       if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
                elem := peer.device.NewOutboundElement()
                select {
                case peer.queue.staged <- elem:
@@ -91,7 +90,7 @@ func (peer *Peer) SendKeepalive() {
 
 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
        if !isRetry {
-               atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+               peer.timers.handshakeAttempts.Store(0)
        }
 
        peer.handshake.mutex.RLock()
@@ -193,7 +192,7 @@ func (peer *Peer) keepKeyFreshSending() {
        if keypair == nil {
                return
        }
-       nonce := atomic.LoadUint64(&keypair.sendNonce)
+       nonce := keypair.sendNonce.Load()
        if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
                peer.SendHandshakeInitiation(false)
        }
@@ -269,7 +268,7 @@ func (device *Device) RoutineReadFromTUN() {
                if peer == nil {
                        continue
                }
-               if peer.isRunning.Get() {
+               if peer.isRunning.Load() {
                        peer.StagePacket(elem)
                        elem = nil
                        peer.SendStagedPackets()
@@ -300,7 +299,7 @@ top:
        }
 
        keypair := peer.keypairs.Current()
-       if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+       if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
                peer.SendHandshakeInitiation(false)
                return
        }
@@ -309,9 +308,9 @@ top:
                select {
                case elem := <-peer.queue.staged:
                        elem.peer = peer
-                       elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
+                       elem.nonce = keypair.sendNonce.Add(1) - 1
                        if elem.nonce >= RejectAfterMessages {
-                               atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
+                               keypair.sendNonce.Store(RejectAfterMessages)
                                peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
                                goto top
                        }
@@ -320,7 +319,7 @@ top:
                        elem.Lock()
 
                        // add to parallel and sequential queue
-                       if peer.isRunning.Get() {
+                       if peer.isRunning.Load() {
                                peer.queue.outbound.c <- elem
                                peer.device.queue.encryption.c <- elem
                        } else {
@@ -385,7 +384,7 @@ func (device *Device) RoutineEncryption(id int) {
                binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 
                // pad content to multiple of 16
-               paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
+               paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
                elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
 
                // encrypt content and release to consumer
@@ -419,7 +418,7 @@ func (peer *Peer) RoutineSequentialSender() {
                        return
                }
                elem.Lock()
-               if !peer.isRunning.Get() {
+               if !peer.isRunning.Load() {
                        // peer has been stopped; return re-usable elems to the shared pool.
                        // This is an optimization only. It is possible for the peer to be stopped
                        // immediately after this check, in which case, elem will get processed.
index 4d2d0f88a921da8a9512147f440f9c765c6c8339..c8ef8877aade46bc4c23f6aab4222b60bcdecf57 100644 (file)
@@ -9,7 +9,6 @@ package device
 
 import (
        "sync"
-       "sync/atomic"
        "time"
        _ "unsafe"
 )
@@ -74,11 +73,11 @@ func (timer *Timer) IsPending() bool {
 }
 
 func (peer *Peer) timersActive() bool {
-       return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
+       return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
 }
 
 func expiredRetransmitHandshake(peer *Peer) {
-       if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
+       if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
                peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
 
                if peer.timersActive() {
@@ -97,8 +96,8 @@ func expiredRetransmitHandshake(peer *Peer) {
                        peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
                }
        } else {
-               atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
-               peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+               peer.timers.handshakeAttempts.Add(1)
+               peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
 
                /* We clear the endpoint address src address, in case this is the cause of trouble. */
                peer.Lock()
@@ -113,8 +112,8 @@ func expiredRetransmitHandshake(peer *Peer) {
 
 func expiredSendKeepalive(peer *Peer) {
        peer.SendKeepalive()
-       if peer.timers.needAnotherKeepalive.Get() {
-               peer.timers.needAnotherKeepalive.Set(false)
+       if peer.timers.needAnotherKeepalive.Load() {
+               peer.timers.needAnotherKeepalive.Store(false)
                if peer.timersActive() {
                        peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
                }
@@ -138,7 +137,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
 }
 
 func expiredPersistentKeepalive(peer *Peer) {
-       if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+       if peer.persistentKeepaliveInterval.Load() > 0 {
                peer.SendKeepalive()
        }
 }
@@ -156,7 +155,7 @@ func (peer *Peer) timersDataReceived() {
                if !peer.timers.sendKeepalive.IsPending() {
                        peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
                } else {
-                       peer.timers.needAnotherKeepalive.Set(true)
+                       peer.timers.needAnotherKeepalive.Store(true)
                }
        }
 }
@@ -187,9 +186,9 @@ func (peer *Peer) timersHandshakeComplete() {
        if peer.timersActive() {
                peer.timers.retransmitHandshake.Del()
        }
-       atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
-       peer.timers.sentLastMinuteHandshake.Set(false)
-       atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+       peer.timers.handshakeAttempts.Store(0)
+       peer.timers.sentLastMinuteHandshake.Store(false)
+       peer.lastHandshakeNano.Store(time.Now().UnixNano())
 }
 
 /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -201,7 +200,7 @@ func (peer *Peer) timersSessionDerived() {
 
 /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
 func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
-       keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+       keepalive := peer.persistentKeepaliveInterval.Load()
        if keepalive > 0 && peer.timersActive() {
                peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
        }
@@ -216,9 +215,9 @@ func (peer *Peer) timersInit() {
 }
 
 func (peer *Peer) timersStart() {
-       atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
-       peer.timers.sentLastMinuteHandshake.Set(false)
-       peer.timers.needAnotherKeepalive.Set(false)
+       peer.timers.handshakeAttempts.Store(0)
+       peer.timers.sentLastMinuteHandshake.Store(false)
+       peer.timers.needAnotherKeepalive.Store(false)
 }
 
 func (peer *Peer) timersStop() {
index 4af9548218f981064913c55507185667c13f8098..d94bde1e1ab8e79fc58aa64b02ac319c86a0e2f1 100644 (file)
@@ -7,7 +7,6 @@ package device
 
 import (
        "fmt"
-       "sync/atomic"
 
        "golang.zx2c4.com/wireguard/tun"
 )
@@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
                                tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
                                mtu = MaxContentSize
                        }
-                       old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
+                       old := device.tun.mtu.Swap(int32(mtu))
                        if int(old) != mtu {
                                device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
                        }
index 30dd97e8b41679ea25d650c1baf0782ba6d79f32..550a0323c0469934f233f645ce0c31ea7b506a5e 100644 (file)
@@ -16,7 +16,6 @@ import (
        "strconv"
        "strings"
        "sync"
-       "sync/atomic"
        "time"
 
        "golang.zx2c4.com/wireguard/ipc"
@@ -112,15 +111,15 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
                                        sendf("endpoint=%s", peer.endpoint.DstToString())
                                }
 
-                               nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+                               nano := peer.lastHandshakeNano.Load()
                                secs := nano / time.Second.Nanoseconds()
                                nano %= time.Second.Nanoseconds()
 
                                sendf("last_handshake_time_sec=%d", secs)
                                sendf("last_handshake_time_nsec=%d", nano)
-                               sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
-                               sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
-                               sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
+                               sendf("tx_bytes=%d", peer.txBytes.Load())
+                               sendf("rx_bytes=%d", peer.rxBytes.Load())
+                               sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
 
                                device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
                                        sendf("allowed_ip=%s", prefix.String())
@@ -358,7 +357,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
                        return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
                }
 
-               old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
+               old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
 
                // Send immediate keepalive if we're turning it on and before it wasn't on.
                peer.pkaOn = old == 0 && secs != 0
diff --git a/go.mod b/go.mod
index d0d58b30a8c21bbc49463000877d3c591ccb3885..c180d1b4982f460f31aac7dbef089d4182c69a42 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module golang.zx2c4.com/wireguard
 
-go 1.18
+go 1.19
 
 require (
        golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
index c5dd48a188265097a62fd270bbeaad9f7641944f..ec9b8d44eb231178c51502b68178e74a45353db8 100644 (file)
@@ -54,7 +54,7 @@ type file struct {
        handle        windows.Handle
        wg            sync.WaitGroup
        wgLock        sync.RWMutex
-       closing       uint32 // used as atomic boolean
+       closing       atomic.Bool
        socket        bool
        readDeadline  deadlineHandler
        writeDeadline deadlineHandler
@@ -65,7 +65,7 @@ type deadlineHandler struct {
        channel     timeoutChan
        channelLock sync.RWMutex
        timer       *time.Timer
-       timedout    uint32 // used as atomic boolean
+       timedout    atomic.Bool
 }
 
 // makeFile makes a new file from an existing file handle
@@ -89,7 +89,7 @@ func makeFile(h windows.Handle) (*file, error) {
 func (f *file) closeHandle() {
        f.wgLock.Lock()
        // Atomically set that we are closing, releasing the resources only once.
-       if atomic.SwapUint32(&f.closing, 1) == 0 {
+       if f.closing.Swap(true) == false {
                f.wgLock.Unlock()
                // cancel all IO and wait for it to complete
                windows.CancelIoEx(f.handle, nil)
@@ -112,7 +112,7 @@ func (f *file) Close() error {
 // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
 func (f *file) prepareIo() (*ioOperation, error) {
        f.wgLock.RLock()
-       if atomic.LoadUint32(&f.closing) == 1 {
+       if f.closing.Load() {
                f.wgLock.RUnlock()
                return nil, os.ErrClosed
        }
@@ -144,7 +144,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
                return int(bytes), err
        }
 
-       if atomic.LoadUint32(&f.closing) == 1 {
+       if f.closing.Load() {
                windows.CancelIoEx(f.handle, &c.o)
        }
 
@@ -160,7 +160,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
        case r = <-c.ch:
                err = r.err
                if err == windows.ERROR_OPERATION_ABORTED {
-                       if atomic.LoadUint32(&f.closing) == 1 {
+                       if f.closing.Load() {
                                err = os.ErrClosed
                        }
                } else if err != nil && f.socket {
@@ -192,7 +192,7 @@ func (f *file) Read(b []byte) (int, error) {
        }
        defer f.wg.Done()
 
-       if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
+       if f.readDeadline.timedout.Load() {
                return 0, os.ErrDeadlineExceeded
        }
 
@@ -219,7 +219,7 @@ func (f *file) Write(b []byte) (int, error) {
        }
        defer f.wg.Done()
 
-       if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
+       if f.writeDeadline.timedout.Load() {
                return 0, os.ErrDeadlineExceeded
        }
 
@@ -256,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
                }
                d.timer = nil
        }
-       atomic.StoreUint32(&d.timedout, 0)
+       d.timedout.Store(false)
 
        select {
        case <-d.channel:
@@ -271,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
        }
 
        timeoutIO := func() {
-               atomic.StoreUint32(&d.timedout, 1)
+               d.timedout.Store(true)
                close(d.channel)
        }
 
index 6db5ea31e03315dcc503a0bb70882fb2c660c306..92cc1ee0f910c74aeab712b9774c0319127f49ba 100644 (file)
@@ -29,7 +29,7 @@ type pipe struct {
 
 type messageBytePipe struct {
        pipe
-       writeClosed int32
+       writeClosed atomic.Bool
        readEOF     bool
 }
 
@@ -51,17 +51,17 @@ func (f *pipe) SetDeadline(t time.Time) error {
 
 // CloseWrite closes the write side of a message pipe in byte mode.
 func (f *messageBytePipe) CloseWrite() error {
-       if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
+       if !f.writeClosed.CompareAndSwap(false, true) {
                return io.ErrClosedPipe
        }
        err := f.file.Flush()
        if err != nil {
-               atomic.StoreInt32(&f.writeClosed, 0)
+               f.writeClosed.Store(false)
                return err
        }
        _, err = f.file.Write(nil)
        if err != nil {
-               atomic.StoreInt32(&f.writeClosed, 0)
+               f.writeClosed.Store(false)
                return err
        }
        return nil
@@ -70,7 +70,7 @@ func (f *messageBytePipe) CloseWrite() error {
 // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
 // they are used to implement CloseWrite.
 func (f *messageBytePipe) Write(b []byte) (int, error) {
-       if atomic.LoadInt32(&f.writeClosed) != 0 {
+       if f.writeClosed.Load() {
                return 0, io.ErrClosedPipe
        }
        if len(b) == 0 {
index d0571508a261dd72e4a6d3522896ddeb33ab1127..6782fd4cdb4ac1824fcba562bc63ac8bfdfcb5d6 100644 (file)
@@ -26,10 +26,10 @@ const (
 )
 
 type rateJuggler struct {
-       current       uint64
-       nextByteCount uint64
-       nextStartTime int64
-       changing      int32
+       current       atomic.Uint64
+       nextByteCount atomic.Uint64
+       nextStartTime atomic.Int64
+       changing      atomic.Bool
 }
 
 type NativeTun struct {
@@ -42,7 +42,7 @@ type NativeTun struct {
        events    chan Event
        running   sync.WaitGroup
        closeOnce sync.Once
-       close     int32
+       close     atomic.Bool
        forcedMTU int
 }
 
@@ -57,18 +57,14 @@ func procyield(cycles uint32)
 //go:linkname nanotime runtime.nanotime
 func nanotime() int64
 
-//
 // CreateTUN creates a Wintun interface with the given name. Should a Wintun
 // interface with the same name exist, it is reused.
-//
 func CreateTUN(ifname string, mtu int) (Device, error) {
        return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
 }
 
-//
 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
 // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
        wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
        if err != nil {
@@ -113,7 +109,7 @@ func (tun *NativeTun) Events() chan Event {
 func (tun *NativeTun) Close() error {
        var err error
        tun.closeOnce.Do(func() {
-               atomic.StoreInt32(&tun.close, 1)
+               tun.close.Store(true)
                windows.SetEvent(tun.readWait)
                tun.running.Wait()
                tun.session.End()
@@ -144,13 +140,13 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
        tun.running.Add(1)
        defer tun.running.Done()
 retry:
-       if atomic.LoadInt32(&tun.close) == 1 {
+       if tun.close.Load() {
                return 0, os.ErrClosed
        }
        start := nanotime()
-       shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
+       shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
        for {
-               if atomic.LoadInt32(&tun.close) == 1 {
+               if tun.close.Load() {
                        return 0, os.ErrClosed
                }
                packet, err := tun.session.ReceivePacket()
@@ -184,7 +180,7 @@ func (tun *NativeTun) Flush() error {
 func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        tun.running.Add(1)
        defer tun.running.Done()
-       if atomic.LoadInt32(&tun.close) == 1 {
+       if tun.close.Load() {
                return 0, os.ErrClosed
        }
 
@@ -210,7 +206,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 func (tun *NativeTun) LUID() uint64 {
        tun.running.Add(1)
        defer tun.running.Done()
-       if atomic.LoadInt32(&tun.close) == 1 {
+       if tun.close.Load() {
                return 0
        }
        return tun.wt.LUID()
@@ -223,15 +219,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) {
 
 func (rate *rateJuggler) update(packetLen uint64) {
        now := nanotime()
-       total := atomic.AddUint64(&rate.nextByteCount, packetLen)
-       period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+       total := rate.nextByteCount.Add(packetLen)
+       period := uint64(now - rate.nextStartTime.Load())
        if period >= rateMeasurementGranularity {
-               if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+               if !rate.changing.CompareAndSwap(false, true) {
                        return
                }
-               atomic.StoreInt64(&rate.nextStartTime, now)
-               atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
-               atomic.StoreUint64(&rate.nextByteCount, 0)
-               atomic.StoreInt32(&rate.changing, 0)
+               rate.nextStartTime.Store(now)
+               rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+               rate.nextByteCount.Store(0)
+               rate.changing.Store(false)
        }
 }