]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: give handshake state a type
authorDavid Crawshaw <crawshaw@tailscale.com>
Thu, 5 Mar 2020 01:58:39 +0000 (20:58 -0500)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sat, 2 May 2020 07:46:42 +0000 (01:46 -0600)
And unexport handshake constants.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
device/noise-protocol.go

index ee327d2eb551783d8669fdd1dae7180987fc173d..6dcc8313242674ef200a14d7abc11fb3d57d481b 100644 (file)
@@ -7,6 +7,7 @@ package device
 
 import (
        "errors"
+       "fmt"
        "sync"
        "time"
 
@@ -16,14 +17,34 @@ import (
        "golang.zx2c4.com/wireguard/tai64n"
 )
 
+type handshakeState int
+
+// TODO(crawshaw): add commentary describing each state and the transitions
 const (
-       HandshakeZeroed = iota
-       HandshakeInitiationCreated
-       HandshakeInitiationConsumed
-       HandshakeResponseCreated
-       HandshakeResponseConsumed
+       handshakeZeroed = handshakeState(iota)
+       handshakeInitiationCreated
+       handshakeInitiationConsumed
+       handshakeResponseCreated
+       handshakeResponseConsumed
 )
 
+func (hs handshakeState) String() string {
+       switch hs {
+       case handshakeZeroed:
+               return "handshakeZeroed"
+       case handshakeInitiationCreated:
+               return "handshakeInitiationCreated"
+       case handshakeInitiationConsumed:
+               return "handshakeInitiationConsumed"
+       case handshakeResponseCreated:
+               return "handshakeResponseCreated"
+       case handshakeResponseConsumed:
+               return "handshakeResponseConsumed"
+       default:
+               return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
+       }
+}
+
 const (
        NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
        WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@@ -95,7 +116,7 @@ type MessageCookieReply struct {
 }
 
 type Handshake struct {
-       state                     int
+       state                     handshakeState
        mutex                     sync.RWMutex
        hash                      [blake2s.Size]byte       // hash value
        chainKey                  [blake2s.Size]byte       // chain key
@@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
        setZero(h.chainKey[:])
        setZero(h.hash[:])
        h.localIndex = 0
-       h.state = HandshakeZeroed
+       h.state = handshakeZeroed
 }
 
 func (h *Handshake) mixHash(data []byte) {
@@ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
        handshake.localIndex = msg.Sender
 
        handshake.mixHash(msg.Timestamp[:])
-       handshake.state = HandshakeInitiationCreated
+       handshake.state = handshakeInitiationCreated
        return &msg, nil
 }
 
@@ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
        if now.After(handshake.lastInitiationConsumption) {
                handshake.lastInitiationConsumption = now
        }
-       handshake.state = HandshakeInitiationConsumed
+       handshake.state = handshakeInitiationConsumed
 
        handshake.mutex.Unlock()
 
@@ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
        handshake.mutex.Lock()
        defer handshake.mutex.Unlock()
 
-       if handshake.state != HandshakeInitiationConsumed {
+       if handshake.state != handshakeInitiationConsumed {
                return nil, errors.New("handshake initiation must be consumed first")
        }
 
@@ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
                handshake.mixHash(msg.Empty[:])
        }()
 
-       handshake.state = HandshakeResponseCreated
+       handshake.state = handshakeResponseCreated
 
        return &msg, nil
 }
@@ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
                handshake.mutex.RLock()
                defer handshake.mutex.RUnlock()
 
-               if handshake.state != HandshakeInitiationCreated {
+               if handshake.state != handshakeInitiationCreated {
                        return false
                }
 
@@ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
        handshake.hash = hash
        handshake.chainKey = chainKey
        handshake.remoteIndex = msg.Sender
-       handshake.state = HandshakeResponseConsumed
+       handshake.state = handshakeResponseConsumed
 
        handshake.mutex.Unlock()
 
@@ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error {
        var sendKey [chacha20poly1305.KeySize]byte
        var recvKey [chacha20poly1305.KeySize]byte
 
-       if handshake.state == HandshakeResponseConsumed {
+       if handshake.state == handshakeResponseConsumed {
                KDF2(
                        &sendKey,
                        &recvKey,
@@ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error {
                        nil,
                )
                isInitiator = true
-       } else if handshake.state == HandshakeResponseCreated {
+       } else if handshake.state == handshakeResponseCreated {
                KDF2(
                        &recvKey,
                        &sendKey,
@@ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error {
                )
                isInitiator = false
        } else {
-               return errors.New("invalid state for keypair derivation")
+               return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
        }
 
        // zero handshake
@@ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error {
        setZero(handshake.chainKey[:])
        setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
        setZero(handshake.localEphemeral[:])
-       peer.handshake.state = HandshakeZeroed
+       peer.handshake.state = handshakeZeroed
 
        // create AEAD instances