]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Beginning work noise handshake
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 23 Jun 2017 11:41:59 +0000 (13:41 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 23 Jun 2017 11:41:59 +0000 (13:41 +0200)
src/device.go
src/kdf_test.go [new file with mode: 0644]
src/noise_helpers.go [new file with mode: 0644]
src/noise_protocol.go [new file with mode: 0644]
src/noise_test.go [new file with mode: 0644]
src/noise_types.go [moved from src/noise.go with 83% similarity]
src/tai64.go [new file with mode: 0644]

index d03057dd2a28cd4e71217714875939042967d4f6..9f1daa6295a15befa03469b4893c26a828fa2233 100644 (file)
@@ -1,12 +1,17 @@
 package main
 
 import (
+       "math/rand"
        "sync"
 )
 
+/* TODO: Locking may be a little broad here
+ */
+
 type Device struct {
        mutex        sync.RWMutex
        peers        map[NoisePublicKey]*Peer
+       sessions     map[uint32]*Handshake
        privateKey   NoisePrivateKey
        publicKey    NoisePublicKey
        fwMark       uint32
@@ -14,6 +19,19 @@ type Device struct {
        routingTable RoutingTable
 }
 
+func (dev *Device) NewID(h *Handshake) uint32 {
+       dev.mutex.Lock()
+       defer dev.mutex.Unlock()
+       for {
+               id := rand.Uint32()
+               _, ok := dev.sessions[id]
+               if !ok {
+                       dev.sessions[id] = h
+                       return id
+               }
+       }
+}
+
 func (dev *Device) RemovePeer(key NoisePublicKey) {
        dev.mutex.Lock()
        defer dev.mutex.Unlock()
diff --git a/src/kdf_test.go b/src/kdf_test.go
new file mode 100644 (file)
index 0000000..0cce81d
--- /dev/null
@@ -0,0 +1,76 @@
+package main
+
+import (
+       "encoding/hex"
+       "testing"
+)
+
+type KDFTest struct {
+       key   string
+       input string
+       t0    string
+       t1    string
+       t2    string
+}
+
+func assertEquals(t *testing.T, a string, b string) {
+       if a != b {
+               t.Fatal("expected", a, "=", b)
+       }
+}
+
+func TestKDF(t *testing.T) {
+       tests := []KDFTest{
+               {
+                       key:   "746573742d6b6579",
+                       input: "746573742d696e707574",
+                       t0:    "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633",
+                       t1:    "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a",
+                       t2:    "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24",
+               },
+               {
+                       key:   "776972656775617264",
+                       input: "776972656775617264",
+                       t0:    "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8",
+                       t1:    "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f",
+                       t2:    "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160",
+               },
+               {
+                       key:   "",
+                       input: "",
+                       t0:    "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0",
+                       t1:    "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e",
+                       t2:    "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e",
+               },
+       }
+
+       for _, test := range tests {
+               key, _ := hex.DecodeString(test.key)
+               input, _ := hex.DecodeString(test.input)
+               t0, t1, t2 := KDF3(key, input)
+               t0s := hex.EncodeToString(t0[:])
+               t1s := hex.EncodeToString(t1[:])
+               t2s := hex.EncodeToString(t2[:])
+               assertEquals(t, t0s, test.t0)
+               assertEquals(t, t1s, test.t1)
+               assertEquals(t, t2s, test.t2)
+       }
+
+       for _, test := range tests {
+               key, _ := hex.DecodeString(test.key)
+               input, _ := hex.DecodeString(test.input)
+               t0, t1 := KDF2(key, input)
+               t0s := hex.EncodeToString(t0[:])
+               t1s := hex.EncodeToString(t1[:])
+               assertEquals(t, t0s, test.t0)
+               assertEquals(t, t1s, test.t1)
+       }
+
+       for _, test := range tests {
+               key, _ := hex.DecodeString(test.key)
+               input, _ := hex.DecodeString(test.input)
+               t0 := KDF1(key, input)
+               t0s := hex.EncodeToString(t0[:])
+               assertEquals(t, t0s, test.t0)
+       }
+}
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
new file mode 100644 (file)
index 0000000..df25011
--- /dev/null
@@ -0,0 +1,86 @@
+package main
+
+import (
+       "crypto/hmac"
+       "crypto/rand"
+       "golang.org/x/crypto/blake2s"
+       "golang.org/x/crypto/curve25519"
+       "hash"
+)
+
+/* KDF related functions.
+ * HMAC-based Key Derivation Function (HKDF)
+ * https://tools.ietf.org/html/rfc5869
+ */
+
+func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) {
+       mac := hmac.New(func() hash.Hash {
+               h, _ := blake2s.New256(nil)
+               return h
+       }, key)
+       mac.Write(input)
+       mac.Sum(sum[:0])
+}
+
+func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) {
+       HMAC(&t0, key, input)
+       HMAC(&t0, t0[:], []byte{0x1})
+       return
+}
+
+func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) {
+       var prk [blake2s.Size]byte
+       HMAC(&prk, key, input)
+       HMAC(&t0, prk[:], []byte{0x1})
+       HMAC(&t1, prk[:], append(t0[:], 0x2))
+       return
+}
+
+func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) {
+       var prk [blake2s.Size]byte
+       HMAC(&prk, key, input)
+       HMAC(&t0, prk[:], []byte{0x1})
+       HMAC(&t1, prk[:], append(t0[:], 0x2))
+       HMAC(&t2, prk[:], append(t1[:], 0x3))
+       return
+}
+
+/*
+ *
+ */
+
+func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+       return KDF1(c[:], data)
+}
+
+func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+       return blake2s.Sum256(append(h[:], data...))
+}
+
+/* Curve25519 wrappers
+ *
+ * TODO: Rethink this
+ */
+
+func newPrivateKey() (sk NoisePrivateKey, err error) {
+       // clamping: https://cr.yp.to/ecdh.html
+       _, err = rand.Read(sk[:])
+       sk[0] &= 248
+       sk[31] &= 127
+       sk[31] |= 64
+       return
+}
+
+func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
+       apk := (*[NoisePublicKeySize]byte)(&pk)
+       ask := (*[NoisePrivateKeySize]byte)(sk)
+       curve25519.ScalarBaseMult(apk, ask)
+       return
+}
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+       apk := (*[NoisePublicKeySize]byte)(&pk)
+       ask := (*[NoisePrivateKeySize]byte)(sk)
+       curve25519.ScalarMult(&ss, apk, ask)
+       return ss
+}
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
new file mode 100644 (file)
index 0000000..e7c8774
--- /dev/null
@@ -0,0 +1,179 @@
+package main
+
+import (
+       "errors"
+       "golang.org/x/crypto/blake2s"
+       "golang.org/x/crypto/chacha20poly1305"
+       "golang.org/x/crypto/poly1305"
+       "sync"
+)
+
+const (
+       HandshakeInitialCreated = iota
+       HandshakeInitialConsumed
+       HandshakeResponseCreated
+)
+
+const (
+       NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
+       WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
+       WGLabelMAC1       = "mac1----"
+       WGLabelCookie     = "cookie--"
+)
+
+const (
+       MessageInitalType         = 1
+       MessageResponseType       = 2
+       MessageCookieResponseType = 3
+       MessageTransportType      = 4
+)
+
+type MessageInital struct {
+       Type      uint32
+       Sender    uint32
+       Ephemeral NoisePublicKey
+       Static    [NoisePublicKeySize + poly1305.TagSize]byte
+       Timestamp [TAI64NSize + poly1305.TagSize]byte
+       Mac1      [blake2s.Size128]byte
+       Mac2      [blake2s.Size128]byte
+}
+
+type MessageResponse struct {
+       Type      uint32
+       Sender    uint32
+       Reciever  uint32
+       Ephemeral NoisePublicKey
+       Empty     [poly1305.TagSize]byte
+       Mac1      [blake2s.Size128]byte
+       Mac2      [blake2s.Size128]byte
+}
+
+type MessageTransport struct {
+       Type     uint32
+       Reciever uint32
+       Counter  uint64
+       Content  []byte
+}
+
+type Handshake struct {
+       lock         sync.Mutex
+       state        int
+       chainKey     [blake2s.Size]byte // chain key
+       hash         [blake2s.Size]byte // hash value
+       staticStatic NoisePublicKey     // precomputed DH(S_i, S_r)
+       ephemeral    NoisePrivateKey    // ephemeral secret key
+       remoteIndex  uint32             // index for sending
+       device       *Device
+       peer         *Peer
+}
+
+var (
+       ZeroNonce      [chacha20poly1305.NonceSize]byte
+       InitalChainKey [blake2s.Size]byte
+       InitalHash     [blake2s.Size]byte
+)
+
+func init() {
+       InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
+       InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
+}
+
+func (h *Handshake) Precompute() {
+       h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
+}
+
+func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
+
+}
+
+func (h *Handshake) addHash(data []byte) {
+       h.hash = addToHash(h.hash, data)
+}
+
+func (h *Handshake) addChain(data []byte) {
+       h.chainKey = addToChainKey(h.chainKey, data)
+}
+
+func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
+       h.lock.Lock()
+       defer h.lock.Unlock()
+
+       // reset handshake
+
+       var err error
+       h.ephemeral, err = newPrivateKey()
+       if err != nil {
+               return nil, err
+       }
+       h.chainKey = InitalChainKey
+       h.hash = addToHash(InitalHash, h.device.publicKey[:])
+
+       // create ephemeral key
+
+       var msg MessageInital
+       msg.Type = MessageInitalType
+       msg.Sender = h.device.NewID(h)
+       msg.Ephemeral = h.ephemeral.publicKey()
+       h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
+       h.hash = addToHash(h.hash, msg.Ephemeral[:])
+
+       // encrypt long-term "identity key"
+
+       func() {
+               var key [chacha20poly1305.KeySize]byte
+               ss := h.ephemeral.sharedSecret(h.peer.publicKey)
+               h.chainKey, key = KDF2(h.chainKey[:], ss[:])
+               aead, _ := chacha20poly1305.New(key[:])
+               aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
+       }()
+       h.addHash(msg.Static[:])
+
+       // encrypt timestamp
+
+       timestamp := Timestamp()
+       func() {
+               var key [chacha20poly1305.KeySize]byte
+               h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
+               aead, _ := chacha20poly1305.New(key[:])
+               aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
+       }()
+       h.addHash(msg.Timestamp[:])
+       h.state = HandshakeInitialCreated
+       return &msg, nil
+}
+
+func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
+       if msg.Type != MessageInitalType {
+               panic(errors.New("bug: invalid inital message type"))
+       }
+
+       hash := addToHash(InitalHash, h.device.publicKey[:])
+       chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
+       hash = addToHash(hash, msg.Ephemeral[:])
+
+       //
+
+       ephemeral, err := newPrivateKey()
+       if err != nil {
+               return err
+       }
+
+       // update handshake state
+
+       h.lock.Lock()
+       defer h.lock.Unlock()
+
+       h.hash = hash
+       h.chainKey = chainKey
+       h.remoteIndex = msg.Sender
+       h.ephemeral = ephemeral
+       h.state = HandshakeInitialConsumed
+
+       return nil
+
+}
+
+func (h *Handshake) CreateMessageResponse() []byte {
+
+       return nil
+}
diff --git a/src/noise_test.go b/src/noise_test.go
new file mode 100644 (file)
index 0000000..b3ea54f
--- /dev/null
@@ -0,0 +1,38 @@
+package main
+
+import (
+       "testing"
+)
+
+func TestHandshake(t *testing.T) {
+       var dev1 Device
+       var dev2 Device
+
+       var err error
+
+       dev1.privateKey, err = newPrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       dev2.privateKey, err = newPrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       var peer1 Peer
+       var peer2 Peer
+
+       peer1.publicKey = dev1.privateKey.publicKey()
+       peer2.publicKey = dev2.privateKey.publicKey()
+
+       var handshake1 Handshake
+       var handshake2 Handshake
+
+       handshake1.device = &dev1
+       handshake2.device = &dev2
+
+       handshake1.peer = &peer2
+       handshake2.peer = &peer1
+
+}
similarity index 83%
rename from src/noise.go
rename to src/noise_types.go
index 5508f9a526f138659eefe3f27baff9283cb81921..6dae6b2af0d8c96eddfb5f1137b3e33e0bdaa99e 100644 (file)
@@ -12,10 +12,8 @@ const (
 )
 
 type (
-       NoisePublicKey    [NoisePublicKeySize]byte
-       NoisePrivateKey   [NoisePrivateKeySize]byte
-       NoiseSymmetricKey [NoiseSymmetricKeySize]byte
-       NoiseNonce        uint64 // padded to 12-bytes
+       NoisePublicKey  [NoisePublicKeySize]byte
+       NoisePrivateKey [NoisePrivateKeySize]byte
 )
 
 func loadExactHex(dst []byte, src string) error {
diff --git a/src/tai64.go b/src/tai64.go
new file mode 100644 (file)
index 0000000..d0d1432
--- /dev/null
@@ -0,0 +1,23 @@
+package main
+
+import (
+       "encoding/binary"
+       "time"
+)
+
+const (
+       TAI64NBase = uint64(4611686018427387914)
+       TAI64NSize = 12
+)
+
+type TAI64N [TAI64NSize]byte
+
+func Timestamp() TAI64N {
+       var tai64n TAI64N
+       now := time.Now()
+       secs := TAI64NBase + uint64(now.Unix())
+       nano := uint32(now.UnixNano())
+       binary.BigEndian.PutUint64(tai64n[:], secs)
+       binary.BigEndian.PutUint32(tai64n[8:], nano)
+       return tai64n
+}