import (
"errors"
+ "fmt"
"sync"
"time"
"golang.zx2c4.com/wireguard/tai64n"
)
+type handshakeState int
+
+// TODO(crawshaw): add commentary describing each state and the transitions
const (
- HandshakeZeroed = iota
- HandshakeInitiationCreated
- HandshakeInitiationConsumed
- HandshakeResponseCreated
- HandshakeResponseConsumed
+ handshakeZeroed = handshakeState(iota)
+ handshakeInitiationCreated
+ handshakeInitiationConsumed
+ handshakeResponseCreated
+ handshakeResponseConsumed
)
+func (hs handshakeState) String() string {
+ switch hs {
+ case handshakeZeroed:
+ return "handshakeZeroed"
+ case handshakeInitiationCreated:
+ return "handshakeInitiationCreated"
+ case handshakeInitiationConsumed:
+ return "handshakeInitiationConsumed"
+ case handshakeResponseCreated:
+ return "handshakeResponseCreated"
+ case handshakeResponseConsumed:
+ return "handshakeResponseConsumed"
+ default:
+ return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
+ }
+}
+
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
}
type Handshake struct {
- state int
+ state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
- h.state = HandshakeZeroed
+ h.state = handshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
- handshake.state = HandshakeInitiationCreated
+ handshake.state = handshakeInitiationCreated
return &msg, nil
}
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
- handshake.state = HandshakeInitiationConsumed
+ handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock()
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationConsumed {
+ if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
handshake.mixHash(msg.Empty[:])
}()
- handshake.state = HandshakeResponseCreated
+ handshake.state = handshakeResponseCreated
return &msg, nil
}
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
- if handshake.state != HandshakeInitiationCreated {
+ if handshake.state != handshakeInitiationCreated {
return false
}
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
- handshake.state = HandshakeResponseConsumed
+ handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock()
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
- if handshake.state == HandshakeResponseConsumed {
+ if handshake.state == handshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
nil,
)
isInitiator = true
- } else if handshake.state == HandshakeResponseCreated {
+ } else if handshake.state == handshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
)
isInitiator = false
} else {
- return errors.New("invalid state for keypair derivation")
+ return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
}
// zero handshake
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
- peer.handshake.state = HandshakeZeroed
+ peer.handshake.state = handshakeZeroed
// create AEAD instances