type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
- isOpen uint32
+ isOpen atomic.Uint32 // 0, 1, or 2
}
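Throughout this diff, bare integer fields driven by `atomic.LoadUint32`/`StoreUint32` become Go 1.19 `atomic.Uint32` values with methods. Below is a self-contained sketch of the tri-state lifecycle that `isOpen` encodes (0 = closed, 1 = open, 2 = closing); the `conn` type and the `CompareAndSwap` variant are illustrative, not from the codebase:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

type conn struct {
	state atomic.Uint32 // 0 = closed, 1 = open, 2 = closing
}

func (c *conn) open() bool {
	// CompareAndSwap makes "open only when closed" explicit; the diff
	// itself checks Load() != 0 under a mutex before Store(1).
	return c.state.CompareAndSwap(0, 1)
}

func (c *conn) close() {
	c.state.Store(2) // tell in-flight receivers to bail out with net.ErrClosed
	// ... tear down sockets here ...
	c.state.Store(0)
}

func main() {
	var c conn
	fmt.Println(c.open()) // true
	fmt.Println(c.open()) // false: already open
	c.close()
	fmt.Println(c.open()) // true: reopened after close
}
```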
func NewDefaultBind() Bind { return NewWinRingBind() }
}
func (bind *WinRingBind) closeAndZero() {
- atomic.StoreUint32(&bind.isOpen, 0)
+ bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
bind.closeAndZero()
}
}()
- if atomic.LoadUint32(&bind.isOpen) != 0 {
+ if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
return nil, 0, err
}
}
- atomic.StoreUint32(&bind.isOpen, 1)
+ bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
- if atomic.LoadUint32(&bind.isOpen) != 1 {
+ if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
- atomic.StoreUint32(&bind.isOpen, 2)
+ bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
-func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
- if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
if err != nil {
return 0, nil, err
}
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
return bind.v6.Receive(buf, &bind.isOpen)
}
-func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
- if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
+ if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
if err != nil {
return err
}
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
- if atomic.LoadUint32(&bind.isOpen) != 1 {
+ if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
- if atomic.LoadUint32(&bind.isOpen) != 1 {
+ if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
+++ /dev/null
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "reflect"
- "testing"
- "unsafe"
-)
-
-func checkAlignment(t *testing.T, name string, offset uintptr) {
- t.Helper()
- if offset%8 != 0 {
- t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
- }
-}
-
-// TestPeerAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestPeerAlignment(t *testing.T) {
- var p Peer
-
- typ := reflect.TypeOf(&p).Elem()
- t.Logf("Peer type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
-
- checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
- checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
-}
-
-// TestDeviceAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestDeviceAlignment(t *testing.T) {
- var d Device
-
- typ := reflect.TypeOf(&d).Elem()
- t.Logf("Device type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
- checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
-}
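The tests deleted above existed because pre-1.19 atomics required manually keeping 64-bit fields 8-byte aligned on 32-bit platforms, typically by placing them first in the struct. Go 1.19's `atomic.Int64` and `atomic.Uint64` are guaranteed 64-bit aligned wherever they appear, which is why this file (and the "Keep this 8-byte aligned" comment further down) can go away. A minimal sketch, assuming Go 1.19+ (`counters` is an illustrative type):

```go
package main

import (
	"fmt"
	"sync/atomic"
	"unsafe"
)

type counters struct {
	flag byte         // on 32-bit, this used to force manual field reordering
	n    atomic.Int64 // Go 1.19 pads so this is always 8-byte aligned
}

func main() {
	var c counters
	c.n.Add(1)                               // safe even on GOARCH=386 or arm
	fmt.Println(unsafe.Offsetof(c.n)%8 == 0) // true on every platform
}
```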
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
- state uint32 // actually a deviceState, but typed uint32 for convenience
+ state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
// mu protects state changes.
keyMap map[NoisePublicKey]*Peer
}
- // Keep this 8-byte aligned
rate struct {
- underLoadUntil int64
+ underLoadUntil atomic.Int64
limiter ratelimiter.Ratelimiter
}
tun struct {
device tun.Device
- mtu int32
+ mtu atomic.Int32
}
ipcMutex sync.RWMutex
// There are three states: down, up, closed.
// Transitions:
//
-// down -----+
-//   ↑↓      ↓
-//   up -> closed
-//
+//	down -----+
+//	  ↑↓      ↓
+//	  up -> closed
type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
- return deviceState(atomic.LoadUint32(&device.state.state))
+ return deviceState(device.state.state.Load())
}
// isClosed reports whether the device is closed (or is closing).
case old:
return nil
case deviceStateUp:
- atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
+ device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked()
if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
- atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
+ device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked()
if err == nil {
err = errDown
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
- atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
+ device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
// check if recently under load
- return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
+ return device.rate.underLoadUntil.Load() > now.UnixNano()
}
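The under-load logic above implements a sticky window: detecting load stores a deadline, and the device keeps reporting under-load until that deadline passes, so cookie-based DoS protection does not flap on and off with each packet. A minimal sketch of the pattern with `atomic.Int64` (the `loadTracker` type and `note` helper are illustrative):

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type loadTracker struct {
	until atomic.Int64 // UnixNano deadline; zero value means never under load
}

func (l *loadTracker) note(busy bool, window time.Duration) bool {
	now := time.Now()
	if busy {
		l.until.Store(now.Add(window).UnixNano())
		return true
	}
	return l.until.Load() > now.UnixNano() // still inside the window?
}

func main() {
	var l loadTracker
	fmt.Println(l.note(true, 50*time.Millisecond))  // true: busy now
	fmt.Println(l.note(false, 50*time.Millisecond)) // true: window still open
	time.Sleep(60 * time.Millisecond)
	fmt.Println(l.note(false, 50*time.Millisecond)) // false: window expired
}
```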
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
- device.state.state = uint32(deviceStateDown)
+ device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
- device.tun.mtu = int32(mtu)
+ device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
if device.isClosed() {
return
}
- atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
+ device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing")
device.tun.device.Close()
// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
- var recv uint64
+ var recv atomic.Uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
var start time.Time
for {
<-pair[0].tun.Inbound
- new := atomic.AddUint64(&recv, 1)
+ new := recv.Add(1)
if new == 1 {
start = time.Now()
}
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
- for atomic.LoadUint64(&recv) != uint64(b.N) {
+ for recv.Load() != uint64(b.N) {
sent++
pingc <- ping
}
"sync"
"sync/atomic"
"time"
- "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
*/
type Keypair struct {
- sendNonce uint64 // accessed atomically
+ sendNonce atomic.Uint64
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.Filter
sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair
-}
-
-func (kp *Keypairs) storeNext(next *Keypair) {
- atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
-}
-
-func (kp *Keypairs) loadNext() *Keypair {
- return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+ next atomic.Pointer[Keypair]
}
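`atomic.Pointer[Keypair]` replaces the two hand-rolled `unsafe.Pointer` shims deleted above with typed, generic operations. A small sketch, with an illustrative `session` type standing in for `Keypair`:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

type session struct{ id int }

func main() {
	var next atomic.Pointer[session]

	next.Store(&session{id: 1})
	if s := next.Load(); s != nil {
		fmt.Println("next session:", s.id)
	}

	// Swap returns the previous pointer, handy for handing off ownership.
	old := next.Swap(nil)
	fmt.Println("retired:", old.id, "now nil:", next.Load() == nil)
}
```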
func (kp *Keypairs) Current() *Keypair {
+++ /dev/null
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
- AtomicFalse = int32(iota)
- AtomicTrue
-)
-
-type AtomicBool struct {
- int32
-}
-
-func (a *AtomicBool) Get() bool {
- return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- atomic.StoreInt32(&a.int32, flag)
-}
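With Go 1.19, the hand-rolled `AtomicBool` deleted above is subsumed by the standard library's `atomic.Bool`, including the `Swap`-based run-once idiom that `closeHandle` and `CloseWrite` use later in this diff. A minimal sketch (the `closeOnce` closure is illustrative):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var closed atomic.Bool

	closeOnce := func() {
		// The first caller sees Swap return false and does the work;
		// every later caller sees true and returns immediately.
		if closed.Swap(true) {
			return
		}
		fmt.Println("released resources")
	}

	closeOnce() // prints once
	closeOnce() // no-op
	fmt.Println("closed:", closed.Load())
}
```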
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil || !peer.isRunning.Get() {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.loadNext()
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.storeNext(keypair)
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.loadNext()
- keypairs.storeNext(nil)
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}
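`ReceivedWithKeypair` above is a check/lock/re-check pattern: a cheap atomic `Load` filters the common case before the mutex is taken, and the test is repeated under the lock because the value may change in between. A standalone sketch under those assumptions (the `box` type and `promote` method are illustrative):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type box struct {
	mu   sync.Mutex
	next atomic.Pointer[int]
}

// promote clears next only if it still holds expected.
func (b *box) promote(expected *int) bool {
	if b.next.Load() != expected {
		return false // fast path: no lock needed
	}
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.next.Load() != expected {
		return false // lost a race between the two checks
	}
	b.next.Store(nil)
	return true
}

func main() {
	v := 42
	var b box
	b.next.Store(&v)
	fmt.Println(b.promote(&v)) // true
	fmt.Println(b.promote(&v)) // false: already consumed
}
```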
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.loadNext()
+ key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current
// encrypting / decryption test
)
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint conn.Endpoint
- stopping sync.WaitGroup // routines pending stop
-
- // These fields are accessed with atomic operations, which must be
- // 64-bit aligned even on 32-bit platforms. Go guarantees that an
- // allocated struct will be 64-bit aligned. So we place
- // atomically-accessed fields up front, so that they can share in
- // this alignment before smaller fields throw it off.
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
- }
+ isRunning atomic.Bool
+ sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ endpoint conn.Endpoint
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes sent to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nanoseconds since epoch
disableRoaming bool
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
state struct {
cookieGenerator CookieGenerator
trieEntries list.List
- persistentKeepaliveInterval uint32 // accessed atomically
+ persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
err := peer.device.net.bind.Send(buffer, peer.endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ peer.txBytes.Add(uint64(len(buffer)))
}
return err
}
peer.state.Lock()
defer peer.state.Unlock()
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
- peer.isRunning.Set(true)
+ peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.loadNext())
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- next := keypairs.loadNext()
- atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
pool sync.Pool
cond sync.Cond
lock sync.Mutex
- count uint32
+ count atomic.Uint32
max uint32
}
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
- for atomic.LoadUint32(&p.count) >= p.max {
+ for p.count.Load() >= p.max {
p.cond.Wait()
}
- atomic.AddUint32(&p.count, 1)
+ p.count.Add(1)
p.lock.Unlock()
}
return p.pool.Get()
if p.max == 0 {
return
}
- atomic.AddUint32(&p.count, ^uint32(0))
+ p.count.Add(^uint32(0))
p.cond.Signal()
}
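`Put` decrements with `p.count.Add(^uint32(0))` because `atomic.Uint32` has no subtract method: adding `^uint32(0)` (all bits set) wraps modulo 2^32, which is exactly minus one, the same trick the old `atomic.AddUint32` call used. A quick check:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var n atomic.Uint32
	n.Store(3)
	n.Add(^uint32(0)) // adds 0xFFFFFFFF, which wraps to -1 mod 2^32
	n.Add(^uint32(0))
	fmt.Println(n.Load()) // 1
}
```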
func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup
- trials := int32(100000)
+ var trials atomic.Int32
+ startTrials := int32(100000)
if raceEnabled {
// This test can be very slow with -race.
- trials /= 10
+ startTrials /= 10
}
+ trials.Store(startTrials)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
- max := uint32(0)
+ var max atomic.Uint32
updateMax := func() {
- count := atomic.LoadUint32(&p.count)
+ count := p.count.Load()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
- old := atomic.LoadUint32(&max)
+ old := max.Load()
if count <= old {
break
}
- if atomic.CompareAndSwapUint32(&max, old, count) {
+ if max.CompareAndSwap(old, count) {
break
}
}
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
- for atomic.AddInt32(&trials, -1) > 0 {
+ for trials.Add(-1) > 0 {
updateMax()
x := p.Get()
updateMax()
}()
}
wg.Wait()
- if max != p.max {
+ if max.Load() != p.max {
- t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
+ t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max.Load(), p.max)
}
}
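`updateMax` in the test above uses the classic compare-and-swap retry loop for maintaining a monotonic maximum. A standalone sketch (the `storeMax` helper is illustrative):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func storeMax(max *atomic.Uint32, candidate uint32) {
	for {
		old := max.Load()
		if candidate <= old {
			return // already large enough
		}
		if max.CompareAndSwap(old, candidate) {
			return // we won the race
		}
		// lost a race with another writer; reload and retry
	}
}

func main() {
	var max atomic.Uint32
	for _, v := range []uint32{3, 7, 5, 9, 2} {
		storeMax(&max, v)
	}
	fmt.Println(max.Load()) // 9
}
```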
func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup
- trials := int32(b.N)
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
- for atomic.AddInt32(&trials, -1) > 0 {
+ for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
"errors"
"net"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
- if peer.timers.sentLastMinuteHandshake.Get() {
+ if peer.timers.sentLastMinuteHandshake.Load() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
- peer.timers.sentLastMinuteHandshake.Set(true)
+ peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false)
}
}
elem.Lock()
// add to decryption queues
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
peer.queue.inbound.c <- elem
device.queue.decryption.c <- elem
buffer = device.GetMessageBuffer()
// consume reply
- if peer := entry.peer; peer.isRunning.Get() {
+ if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response")
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer)
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse()
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer)
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
+ peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
"net"
"os"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() {
- if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
+ if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
select {
case peer.queue.staged <- elem:
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.handshakeAttempts.Store(0)
}
peer.handshake.mutex.RLock()
if keypair == nil {
return
}
- nonce := atomic.LoadUint64(&keypair.sendNonce)
+ nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
if peer == nil {
continue
}
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
peer.StagePacket(elem)
elem = nil
peer.SendStagedPackets()
}
keypair := peer.keypairs.Current()
- if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+ if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
select {
case elem := <-peer.queue.staged:
elem.peer = peer
- elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
+ elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
- atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
+ keypair.sendNonce.Store(RejectAfterMessages)
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
goto top
}
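The `sendNonce.Add(1) - 1` line above is a fetch-and-increment: `Add` returns the new value, so subtracting one atomically reserves the previous value as this packet's nonce without a lock, and the clamp keeps the counter from running past the message limit. A hedged sketch (`claimNonce` and the `rejectAfter` constant are illustrative stand-ins):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

const rejectAfter = 4 // stand-in for RejectAfterMessages

func claimNonce(ctr *atomic.Uint64) (uint64, bool) {
	n := ctr.Add(1) - 1
	if n >= rejectAfter {
		ctr.Store(rejectAfter) // clamp so the counter cannot wrap
		return 0, false        // caller must rekey before sending
	}
	return n, true
}

func main() {
	var ctr atomic.Uint64
	for i := 0; i < 6; i++ {
		n, ok := claimNonce(&ctr)
		fmt.Println(n, ok) // 0..3 true, then false twice
	}
}
```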
elem.Lock()
// add to parallel and sequential queue
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
peer.queue.outbound.c <- elem
peer.device.queue.encryption.c <- elem
} else {
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
+ paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
return
}
elem.Lock()
- if !peer.isRunning.Get() {
+ if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
import (
"sync"
- "sync/atomic"
"time"
_ "unsafe"
)
}
func (peer *Peer) timersActive() bool {
- return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
+ return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
}
func expiredRetransmitHandshake(peer *Peer) {
- if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
+ if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
- atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
- peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+ peer.timers.handshakeAttempts.Add(1)
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint's source address, in case this is the cause of trouble. */
peer.Lock()
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
- if peer.timers.needAnotherKeepalive.Get() {
- peer.timers.needAnotherKeepalive.Set(false)
+ if peer.timers.needAnotherKeepalive.Load() {
+ peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
}
func expiredPersistentKeepalive(peer *Peer) {
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
- peer.timers.needAnotherKeepalive.Set(true)
+ peer.timers.needAnotherKeepalive.Store(true)
}
}
}
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.lastHandshakeNano.Store(time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
- keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+ keepalive := peer.persistentKeepaliveInterval.Load()
if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
}
func (peer *Peer) timersStart() {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- peer.timers.needAnotherKeepalive.Set(false)
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.timers.needAnotherKeepalive.Store(false)
}
func (peer *Peer) timersStop() {
import (
"fmt"
- "sync/atomic"
"golang.zx2c4.com/wireguard/tun"
)
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize
}
- old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
+ old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}
"strconv"
"strings"
"sync"
- "sync/atomic"
"time"
"golang.zx2c4.com/wireguard/ipc"
sendf("endpoint=%s", peer.endpoint.DstToString())
}
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
- sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
- sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
- sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
- old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send an immediate keepalive if we're turning it on and it wasn't on before.
peer.pkaOn = old == 0 && secs != 0
module golang.zx2c4.com/wireguard
-go 1.18
+go 1.19
require (
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
- closing uint32 // used as atomic boolean
+ closing atomic.Bool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
- timedout uint32 // used as atomic boolean
+ timedout atomic.Bool
}
// makeFile makes a new file from an existing file handle
func (f *file) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
- if atomic.SwapUint32(&f.closing, 1) == 0 {
+ if !f.closing.Swap(true) {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
windows.CancelIoEx(f.handle, nil)
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
- if atomic.LoadUint32(&f.closing) == 1 {
+ if f.closing.Load() {
f.wgLock.RUnlock()
return nil, os.ErrClosed
}
return int(bytes), err
}
- if atomic.LoadUint32(&f.closing) == 1 {
+ if f.closing.Load() {
windows.CancelIoEx(f.handle, &c.o)
}
case r = <-c.ch:
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
- if atomic.LoadUint32(&f.closing) == 1 {
+ if f.closing.Load() {
err = os.ErrClosed
}
} else if err != nil && f.socket {
}
defer f.wg.Done()
- if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
+ if f.readDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
}
defer f.wg.Done()
- if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
+ if f.writeDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
}
d.timer = nil
}
- atomic.StoreUint32(&d.timedout, 0)
+ d.timedout.Store(false)
select {
case <-d.channel:
}
timeoutIO := func() {
- atomic.StoreUint32(&d.timedout, 1)
+ d.timedout.Store(true)
close(d.channel)
}
type messageBytePipe struct {
pipe
- writeClosed int32
+ writeClosed atomic.Bool
readEOF bool
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
- if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
+ if !f.writeClosed.CompareAndSwap(false, true) {
return io.ErrClosedPipe
}
err := f.file.Flush()
if err != nil {
- atomic.StoreInt32(&f.writeClosed, 0)
+ f.writeClosed.Store(false)
return err
}
_, err = f.file.Write(nil)
if err != nil {
- atomic.StoreInt32(&f.writeClosed, 0)
+ f.writeClosed.Store(false)
return err
}
return nil
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
- if atomic.LoadInt32(&f.writeClosed) != 0 {
+ if f.writeClosed.Load() {
return 0, io.ErrClosedPipe
}
if len(b) == 0 {
)
type rateJuggler struct {
- current uint64
- nextByteCount uint64
- nextStartTime int64
- changing int32
+ current atomic.Uint64
+ nextByteCount atomic.Uint64
+ nextStartTime atomic.Int64
+ changing atomic.Bool
}
type NativeTun struct {
events chan Event
running sync.WaitGroup
closeOnce sync.Once
- close int32
+ close atomic.Bool
forcedMTU int
}
//go:linkname nanotime runtime.nanotime
func nanotime() int64
-//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
-//
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
-//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
func (tun *NativeTun) Close() error {
var err error
tun.closeOnce.Do(func() {
- atomic.StoreInt32(&tun.close, 1)
+ tun.close.Store(true)
windows.SetEvent(tun.readWait)
tun.running.Wait()
tun.session.End()
tun.running.Add(1)
defer tun.running.Done()
retry:
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
start := nanotime()
- shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
+ shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
packet, err := tun.session.ReceivePacket()
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
func (tun *NativeTun) LUID() uint64 {
tun.running.Add(1)
defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0
}
return tun.wt.LUID()
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
- total := atomic.AddUint64(&rate.nextByteCount, packetLen)
- period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+ total := rate.nextByteCount.Add(packetLen)
+ period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
- if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+ if !rate.changing.CompareAndSwap(false, true) {
return
}
- atomic.StoreInt64(&rate.nextStartTime, now)
- atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
- atomic.StoreUint64(&rate.nextByteCount, 0)
- atomic.StoreInt32(&rate.changing, 0)
+ rate.nextStartTime.Store(now)
+ rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+ rate.nextByteCount.Store(0)
+ rate.changing.Store(false)
}
}
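`rateJuggler.update` guards its window roll with `changing.CompareAndSwap(false, true)`, a non-blocking try-lock: when several senders race to roll the measurement window, exactly one performs the update and the rest simply skip it instead of blocking on a mutex. A minimal sketch of that idiom:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var changing atomic.Bool
	var updates atomic.Int32
	var wg sync.WaitGroup

	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if !changing.CompareAndSwap(false, true) {
				return // someone else is already updating; skip
			}
			updates.Add(1) // the critical section
			changing.Store(false)
		}()
	}
	wg.Wait()
	fmt.Println("updates ran:", updates.Load()) // at least 1, often fewer than 8
}
```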