]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: use atomic access for unlocked keypair.next
authorJason A. Donenfeld <Jason@zx2c4.com>
Sat, 2 May 2020 07:30:23 +0000 (01:30 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sat, 2 May 2020 07:56:48 +0000 (01:56 -0600)
Go's GC semantics might not always guarantee the safety of this, and the
race detector gets upset too, so instead we wrap this all in atomic
accessors.

Reported-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/keypair.go
device/noise-protocol.go
device/noise_test.go
device/peer.go

index 9c78fa912add274cd8dc0367ee651759c124d8e0..d70c7f43a41d822a07f65b778da82348417ed0af 100644 (file)
@@ -8,7 +8,9 @@ package device
 import (
        "crypto/cipher"
        "sync"
+       "sync/atomic"
        "time"
+       "unsafe"
 
        "golang.zx2c4.com/wireguard/replay"
 )
@@ -38,6 +40,14 @@ type Keypairs struct {
        next     *Keypair
 }
 
+func (kp *Keypairs) storeNext(next *Keypair) {
+       atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
+}
+
+func (kp *Keypairs) loadNext() *Keypair {
+       return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+}
+
 func (kp *Keypairs) Current() *Keypair {
        kp.RLock()
        defer kp.RUnlock()
index a848c4755eb4d895b8bcc95ac11f198bcbcbdb4c..e6f676c17b60d3b3ffa12f317b273ce8c111d2da 100644 (file)
@@ -14,6 +14,7 @@ import (
        "golang.org/x/crypto/blake2s"
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/crypto/poly1305"
+
        "golang.zx2c4.com/wireguard/tai64n"
 )
 
@@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
        defer keypairs.Unlock()
 
        previous := keypairs.previous
-       next := keypairs.next
+       next := keypairs.loadNext()
        current := keypairs.current
 
        if isInitiator {
                if next != nil {
-                       keypairs.next = nil
+                       keypairs.storeNext(nil)
                        keypairs.previous = next
                        device.DeleteKeypair(current)
                } else {
@@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
                device.DeleteKeypair(previous)
                keypairs.current = keypair
        } else {
-               keypairs.next = keypair
+               keypairs.storeNext(keypair)
                device.DeleteKeypair(next)
                keypairs.previous = nil
                device.DeleteKeypair(previous)
@@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
 
 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
        keypairs := &peer.keypairs
-       if keypairs.next != receivedKeypair {
+
+       if keypairs.loadNext() != receivedKeypair {
                return false
        }
        keypairs.Lock()
        defer keypairs.Unlock()
-       if keypairs.next != receivedKeypair {
+       if keypairs.loadNext() != receivedKeypair {
                return false
        }
        old := keypairs.previous
        keypairs.previous = keypairs.current
        peer.device.DeleteKeypair(old)
-       keypairs.current = keypairs.next
-       keypairs.next = nil
+       keypairs.current = keypairs.loadNext()
+       keypairs.storeNext(nil)
        return true
 }
index 6ba3f2e6246e39a1293c69e5cf598fd35ad7a359..b5d58454ba56c885575c2349442ae6584b36b7e2 100644 (file)
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
                t.Fatal("failed to derive keypair for peer 2", err)
        }
 
-       key1 := peer1.keypairs.next
+       key1 := peer1.keypairs.loadNext()
        key2 := peer2.keypairs.current
 
        // encrypting / decryption test
index 79d4981812bfc2ea82a8a6fc2752e090ad613838..899591b873519911da5017e154a0e7db76cf559b 100644 (file)
@@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
        keypairs.Lock()
        device.DeleteKeypair(keypairs.previous)
        device.DeleteKeypair(keypairs.current)
-       device.DeleteKeypair(keypairs.next)
+       device.DeleteKeypair(keypairs.loadNext())
        keypairs.previous = nil
        keypairs.current = nil
-       keypairs.next = nil
+       keypairs.storeNext(nil)
        keypairs.Unlock()
 
        // clear handshake state
@@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
                keypairs.current.sendNonce = RejectAfterMessages
        }
        if keypairs.next != nil {
-               keypairs.next.sendNonce = RejectAfterMessages
+               keypairs.loadNext().sendNonce = RejectAfterMessages
        }
        keypairs.Unlock()
 }