import (
"crypto/cipher"
"sync"
+ "sync/atomic"
"time"
+ "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
next *Keypair
}
+func (kp *Keypairs) storeNext(next *Keypair) {
+ atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
+}
+
+func (kp *Keypairs) loadNext() *Keypair {
+ return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+}
+
func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
+
"golang.zx2c4.com/wireguard/tai64n"
)
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.next
+ next := keypairs.loadNext()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.next = nil
+ keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.next = keypair
+ keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.next != receivedKeypair {
+
+ if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.next != receivedKeypair {
+ if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.next
- keypairs.next = nil
+ keypairs.current = keypairs.loadNext()
+ keypairs.storeNext(nil)
return true
}
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.next
+ key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.next)
+ device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
- keypairs.next = nil
+ keypairs.storeNext(nil)
keypairs.Unlock()
// clear handshake state
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
- keypairs.next.sendNonce = RejectAfterMessages
+ keypairs.loadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}