]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Implemented MAC1/2 calculation
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Tue, 27 Jun 2017 15:33:06 +0000 (17:33 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Tue, 27 Jun 2017 15:33:06 +0000 (17:33 +0200)
13 files changed:
src/config.go
src/constants.go [new file with mode: 0644]
src/cookie.go [deleted file]
src/device.go
src/keypair.go
src/logger.go [new file with mode: 0644]
src/macs_device.go [new file with mode: 0644]
src/macs_peer.go [new file with mode: 0644]
src/macs_test.go [new file with mode: 0644]
src/noise_protocol.go
src/noise_test.go
src/peer.go
src/send.go

index 88651944133f2054314728cdb3ee4c8e68aa956f..cb7e9efe8a09cddd59f5c2a2619802b1536aa041 100644 (file)
@@ -81,7 +81,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        }
 
                case "listen_port":
-                       _, err := fmt.Sscanf(value, "%ud", &device.listenPort)
+                       _, err := fmt.Sscanf(value, "%ud", &device.address.Port)
                        if err != nil {
                                return &IPCError{Code: ipcErrorInvalidPort}
                        }
diff --git a/src/constants.go b/src/constants.go
new file mode 100644 (file)
index 0000000..dc95379
--- /dev/null
@@ -0,0 +1,16 @@
+package main
+
+import (
+       "time"
+)
+
+const (
+       RekeyAfterMessage  = (1 << 64) - (1 << 16) - 1
+       RekeyAfterTime     = time.Second * 120
+       RekeyAttemptTime   = time.Second * 90
+       RekeyTimeout       = time.Second * 5
+       RejectAfterTime    = time.Second * 180
+       RejectAfterMessage = (1 << 64) - (1 << 4) - 1
+       KeepaliveTimeout   = time.Second * 10
+       CookieRefreshTime  = time.Second * 2
+)
diff --git a/src/cookie.go b/src/cookie.go
deleted file mode 100644 (file)
index a6987a2..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-package main
-
-import (
-       "errors"
-       "golang.org/x/crypto/blake2s"
-)
-
-func CalculateCookie(peer *Peer, msg []byte) {
-       size := len(msg)
-
-       if size < blake2s.Size128*2 {
-               panic(errors.New("bug: message too short"))
-       }
-
-       startMac1 := size - (blake2s.Size128 * 2)
-       startMac2 := size - blake2s.Size128
-
-       mac1 := msg[startMac1 : startMac1+blake2s.Size128]
-       mac2 := msg[startMac2 : startMac2+blake2s.Size128]
-
-       peer.mutex.RLock()
-       defer peer.mutex.RUnlock()
-
-       // set mac1
-
-       func() {
-               mac, _ := blake2s.New128(peer.macKey[:])
-               mac.Write(msg[:startMac1])
-               mac.Sum(mac1[:0])
-       }()
-
-       // set mac2
-
-       if peer.cookie != nil {
-               mac, _ := blake2s.New128(peer.cookie)
-               mac.Write(msg[:startMac2])
-               mac.Sum(mac2[:0])
-       }
-}
index 4b8cda01fd80731db2965147263d10ea7d7d8f00..b3484c5fe61e58d1e53041b0f9329e4ce763cdf6 100644 (file)
@@ -1,25 +1,24 @@
 package main
 
 import (
-       "log"
        "net"
        "sync"
 )
 
 type Device struct {
        mtu               int
-       source            *net.UDPAddr // UDP source address
+       fwMark            uint32
+       address           *net.UDPAddr // UDP source address
        conn              *net.UDPConn // UDP "connection"
        mutex             sync.RWMutex
-       peers             map[NoisePublicKey]*Peer
-       indices           IndexTable
        privateKey        NoisePrivateKey
        publicKey         NoisePublicKey
-       fwMark            uint32
-       listenPort        uint16
        routingTable      RoutingTable
-       logger            log.Logger
+       indices           IndexTable
+       log               *Logger
        queueWorkOutbound chan *OutboundWorkQueueElement
+       peers             map[NoisePublicKey]*Peer
+       mac               MacStateDevice
 }
 
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
@@ -30,8 +29,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
 
        device.privateKey = sk
        device.publicKey = sk.publicKey()
+       device.mac.Init(device.publicKey)
 
-       // do precomputations
+       // do DH precomputations
 
        for _, peer := range device.peers {
                h := &peer.handshake
@@ -45,9 +45,9 @@ func (device *Device) Init() {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
+       device.log = NewLogger()
        device.peers = make(map[NoisePublicKey]*Peer)
        device.indices.Init()
-       device.listenPort = 0
        device.routingTable.Reset()
 }
 
index 53e123ff3b73d97c0f39217f9738242350048aad..0b029cebf6fe57567aa76a07c126616b2c203383 100644 (file)
@@ -3,13 +3,16 @@ package main
 import (
        "crypto/cipher"
        "sync"
+       "time"
 )
 
 type KeyPair struct {
-       recv      cipher.AEAD
-       recvNonce uint64
-       send      cipher.AEAD
-       sendNonce uint64
+       recv        cipher.AEAD
+       recvNonce   uint64
+       send        cipher.AEAD
+       sendNonce   uint64
+       isInitiator bool
+       created     time.Time
 }
 
 type KeyPairs struct {
diff --git a/src/logger.go b/src/logger.go
new file mode 100644 (file)
index 0000000..117fe5b
--- /dev/null
@@ -0,0 +1,35 @@
+package main
+
+import (
+       "log"
+       "os"
+)
+
+const (
+       LogLevelError = iota
+       LogLevelInfo
+       LogLevelDebug
+)
+
+type Logger struct {
+       Debug *log.Logger
+       Info  *log.Logger
+       Error *log.Logger
+}
+
+func NewLogger() *Logger {
+       logger := new(Logger)
+       logger.Debug = log.New(os.Stdout,
+               "DEBUG: ",
+               log.Ldate|log.Ltime|log.Lshortfile,
+       )
+       logger.Info = log.New(os.Stdout,
+               "INFO: ",
+               log.Ldate|log.Ltime|log.Lshortfile,
+       )
+       logger.Error = log.New(os.Stdout,
+               "ERROR: ",
+               log.Ldate|log.Ltime|log.Lshortfile,
+       )
+       return logger
+}
diff --git a/src/macs_device.go b/src/macs_device.go
new file mode 100644 (file)
index 0000000..730c361
--- /dev/null
@@ -0,0 +1,161 @@
+package main
+
+import (
+       "crypto/cipher"
+       "crypto/hmac"
+       "crypto/rand"
+       "github.com/aead/chacha20poly1305" // Needed for XChaCha20Poly1305, TODO:
+       "golang.org/x/crypto/blake2s"
+       "net"
+       "sync"
+       "time"
+)
+
+type MacStateDevice struct {
+       mutex     sync.RWMutex
+       refreshed time.Time
+       secret    [blake2s.Size]byte
+       keyMac1   [blake2s.Size]byte
+       xaead     cipher.AEAD
+}
+
+func (state *MacStateDevice) Init(pk NoisePublicKey) {
+       state.mutex.Lock()
+       defer state.mutex.Unlock()
+       func() {
+               hsh, _ := blake2s.New256(nil)
+               hsh.Write([]byte(WGLabelMAC1))
+               hsh.Write(pk[:])
+               hsh.Sum(state.keyMac1[:0])
+       }()
+       state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMac1[:])
+       state.refreshed = time.Time{} // never
+}
+
+func (state *MacStateDevice) CheckMAC1(msg []byte) bool {
+       size := len(msg)
+       startMac1 := size - (blake2s.Size128 * 2)
+       startMac2 := size - blake2s.Size128
+
+       var mac1 [blake2s.Size128]byte
+       func() {
+               mac, _ := blake2s.New128(state.keyMac1[:])
+               mac.Write(msg[:startMac1])
+               mac.Sum(mac1[:0])
+       }()
+
+       return hmac.Equal(mac1[:], msg[startMac1:startMac2])
+}
+
+func (state *MacStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool {
+       state.mutex.RLock()
+       defer state.mutex.RUnlock()
+
+       if time.Now().Sub(state.refreshed) > CookieRefreshTime {
+               return false
+       }
+
+       // derive cookie key
+
+       var cookie [blake2s.Size128]byte
+       func() {
+               port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
+               mac, _ := blake2s.New128(state.secret[:])
+               mac.Write(addr.IP)
+               mac.Write(port[:])
+               mac.Sum(cookie[:0])
+       }()
+
+       // calculate mac of packet
+
+       start := len(msg) - blake2s.Size128
+
+       var mac2 [blake2s.Size128]byte
+       func() {
+               mac, _ := blake2s.New128(cookie[:])
+               mac.Write(msg[:start])
+               mac.Sum(mac2[:0])
+       }()
+
+       return hmac.Equal(mac2[:], msg[start:])
+}
+
+func (device *Device) CreateMessageCookieReply(msg []byte, receiver uint32, addr *net.UDPAddr) (*MessageCookieReply, error) {
+       state := &device.mac
+       state.mutex.RLock()
+
+       // refresh cookie secret
+
+       if time.Now().Sub(state.refreshed) > CookieRefreshTime {
+               state.mutex.RUnlock()
+               state.mutex.Lock()
+               _, err := rand.Read(state.secret[:])
+               if err != nil {
+                       state.mutex.Unlock()
+                       return nil, err
+               }
+               state.refreshed = time.Now()
+               state.mutex.Unlock()
+               state.mutex.RLock()
+       }
+
+       // derive cookie key
+
+       var cookie [blake2s.Size128]byte
+       func() {
+               port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
+               mac, _ := blake2s.New128(state.secret[:])
+               mac.Write(addr.IP)
+               mac.Write(port[:])
+               mac.Sum(cookie[:0])
+       }()
+
+       // encrypt cookie
+
+       size := len(msg)
+
+       startMac1 := size - (blake2s.Size128 * 2)
+       startMac2 := size - blake2s.Size128
+
+       M := msg[startMac1:startMac2]
+
+       reply := new(MessageCookieReply)
+       reply.Type = MessageCookieReplyType
+       reply.Receiver = receiver
+       _, err := rand.Read(reply.Nonce[:])
+       if err != nil {
+               state.mutex.RUnlock()
+               return nil, err
+       }
+       state.xaead.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], M)
+       state.mutex.RUnlock()
+       return reply, nil
+}
+
+func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
+
+       if msg.Type != MessageCookieReplyType {
+               return false
+       }
+
+       // lookup peer
+
+       lookup := device.indices.Lookup(msg.Receiver)
+       if lookup.handshake == nil {
+               return false
+       }
+
+       // decrypt and store cookie
+
+       var cookie [blake2s.Size128]byte
+       state := &lookup.peer.mac
+       state.mutex.Lock()
+       defer state.mutex.Unlock()
+       _, err := state.xaead.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], state.lastMac1[:])
+       if err != nil {
+               return false
+       }
+       state.cookieSet = time.Now()
+       state.cookie = cookie
+       return true
+}
diff --git a/src/macs_peer.go b/src/macs_peer.go
new file mode 100644 (file)
index 0000000..d70c8f3
--- /dev/null
@@ -0,0 +1,73 @@
+package main
+
+import (
+       "crypto/cipher"
+       "errors"
+       "github.com/aead/chacha20poly1305" // Needed for XChaCha20Poly1305, TODO:
+       "golang.org/x/crypto/blake2s"
+       "sync"
+       "time"
+)
+
+type MacStatePeer struct {
+       mutex     sync.RWMutex
+       cookieSet time.Time
+       cookie    [blake2s.Size128]byte
+       lastMac1  [blake2s.Size128]byte
+       keyMac1   [blake2s.Size]byte
+       xaead     cipher.AEAD
+}
+
+func (state *MacStatePeer) Init(pk NoisePublicKey) {
+       state.mutex.Lock()
+       defer state.mutex.Unlock()
+       func() {
+               hsh, _ := blake2s.New256(nil)
+               hsh.Write([]byte(WGLabelMAC1))
+               hsh.Write(pk[:])
+               hsh.Sum(state.keyMac1[:0])
+       }()
+       state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMac1[:])
+       state.cookieSet = time.Time{} // never
+}
+
+func (state *MacStatePeer) AddMacs(msg []byte) {
+       size := len(msg)
+
+       if size < blake2s.Size128*2 {
+               panic(errors.New("bug: message too short"))
+       }
+
+       startMac1 := size - (blake2s.Size128 * 2)
+       startMac2 := size - blake2s.Size128
+
+       mac1 := msg[startMac1 : startMac1+blake2s.Size128]
+       mac2 := msg[startMac2 : startMac2+blake2s.Size128]
+
+       state.mutex.Lock()
+       defer state.mutex.Unlock()
+
+       // set mac1
+
+       func() {
+               mac, _ := blake2s.New128(state.keyMac1[:])
+               mac.Write(msg[:startMac1])
+               mac.Sum(state.lastMac1[:0])
+       }()
+       copy(mac1, state.lastMac1[:])
+
+       // set mac2
+
+       if state.cookieSet.IsZero() {
+               return
+       }
+       if time.Now().Sub(state.cookieSet) > CookieRefreshTime {
+               state.cookieSet = time.Time{}
+               return
+       }
+       func() {
+               mac, _ := blake2s.New128(state.cookie[:])
+               mac.Write(msg[:startMac2])
+               mac.Sum(mac2[:0])
+       }()
+}
diff --git a/src/macs_test.go b/src/macs_test.go
new file mode 100644 (file)
index 0000000..a67ccfb
--- /dev/null
@@ -0,0 +1,113 @@
+package main
+
+import (
+       "bytes"
+       "net"
+       "testing"
+       "testing/quick"
+)
+
+func TestMAC1(t *testing.T) {
+       dev1 := newDevice(t)
+       dev2 := newDevice(t)
+
+       peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
+       peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
+
+       assertEqual(t, peer1.mac.keyMac1[:], dev1.mac.keyMac1[:])
+       assertEqual(t, peer2.mac.keyMac1[:], dev2.mac.keyMac1[:])
+
+       msg1 := make([]byte, 256)
+       copy(msg1, []byte("some content"))
+       peer1.mac.AddMacs(msg1)
+       if dev1.mac.CheckMAC1(msg1) == false {
+               t.Fatal("failed to verify mac1")
+       }
+}
+
+func TestMACs(t *testing.T) {
+       assertion := func(
+               addr net.UDPAddr,
+               addrInvalid net.UDPAddr,
+               sk1 NoisePrivateKey,
+               sk2 NoisePrivateKey,
+               msg []byte,
+               receiver uint32,
+       ) bool {
+               var device1 Device
+               device1.Init()
+               device1.SetPrivateKey(sk1)
+
+               var device2 Device
+               device2.Init()
+               device2.SetPrivateKey(sk2)
+
+               peer1 := device2.NewPeer(device1.privateKey.publicKey())
+               peer2 := device1.NewPeer(device2.privateKey.publicKey())
+
+               if addr.Port < 0 {
+                       return true
+               }
+               addr.Port &= 0xffff
+
+               if len(msg) < 32 {
+                       return true
+               }
+               if bytes.Compare(peer1.mac.keyMac1[:], device1.mac.keyMac1[:]) != 0 {
+                       return false
+               }
+               if bytes.Compare(peer2.mac.keyMac1[:], device2.mac.keyMac1[:]) != 0 {
+                       return false
+               }
+
+               device2.indices.Insert(receiver, IndexTableEntry{
+                       peer:      peer1,
+                       handshake: &peer1.handshake,
+               })
+
+               // test just MAC1
+
+               peer1.mac.AddMacs(msg)
+               if device1.mac.CheckMAC1(msg) == false {
+                       return false
+               }
+
+               // exchange cookie reply
+
+               cr, err := device1.CreateMessageCookieReply(msg, receiver, &addr)
+               if err != nil {
+                       return false
+               }
+
+               if device2.ConsumeMessageCookieReply(cr) == false {
+                       return false
+               }
+
+               // test MAC1 + MAC2
+
+               peer1.mac.AddMacs(msg)
+               if device1.mac.CheckMAC1(msg) == false {
+                       return false
+               }
+               if device1.mac.CheckMAC2(msg, &addr) == false {
+                       return false
+               }
+
+               // test invalid
+
+               if device1.mac.CheckMAC2(msg, &addrInvalid) {
+                       return false
+               }
+               msg[5] ^= 1
+               if device1.mac.CheckMAC1(msg) {
+                       return false
+               }
+
+               return true
+       }
+
+       err := quick.Check(assertion, nil)
+       if err != nil {
+               t.Error(err)
+       }
+}
index bf1db9b39639897a9916095bec0220e5861676c0..e237dbe61873816ecabccf18dfaf6a07341bd358 100644 (file)
@@ -24,15 +24,20 @@ const (
 )
 
 const (
-       MessageInitiationType     = 1
-       MessageResponseType       = 2
-       MessageCookieResponseType = 3
-       MessageTransportType      = 4
+       MessageInitiationType  = 1
+       MessageResponseType    = 2
+       MessageCookieReplyType = 3
+       MessageTransportType   = 4
+)
+
+const (
+       MessageInitiationSize = 148
+       MessageResponseSize   = 92
 )
 
 /* Type is an 8-bit field, followed by 3 nul bytes,
  * by marshalling the messages in little-endian byteorder
- * we can treat these as a 32-bit int
+ * we can treat these as a 32-bit unsigned int (for now)
  *
  */
 
@@ -63,6 +68,13 @@ type MessageTransport struct {
        Content  []byte
 }
 
+type MessageCookieReply struct {
+       Type     uint32
+       Receiver uint32
+       Nonce    [24]byte
+       Cookie   [blake2s.Size128 + poly1305.TagSize]byte
+}
+
 type Handshake struct {
        state                   int
        mutex                   sync.Mutex
index 8450c1c25046fc004988ebb182247578c37dab8d..dab603b6141c925a688ddd15c0c52fb377a3ae29 100644 (file)
@@ -18,6 +18,17 @@ func assertEqual(t *testing.T, a []byte, b []byte) {
        }
 }
 
+func newDevice(t *testing.T) *Device {
+       var device Device
+       sk, err := newPrivateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+       device.Init()
+       device.SetPrivateKey(sk)
+       return &device
+}
+
 func TestCurveWrappers(t *testing.T) {
        sk1, err := newPrivateKey()
        assertNil(t, err)
@@ -36,17 +47,6 @@ func TestCurveWrappers(t *testing.T) {
        }
 }
 
-func newDevice(t *testing.T) *Device {
-       var device Device
-       sk, err := newPrivateKey()
-       if err != nil {
-               t.Fatal(err)
-       }
-       device.Init()
-       device.SetPrivateKey(sk)
-       return &device
-}
-
 func TestNoiseHandshake(t *testing.T) {
 
        dev1 := newDevice(t)
index 6a879cb3454bf622d40bd404fac2ed0982df9fcf..e192b12d19720388086e7948fc6aa44447523ffb 100644 (file)
@@ -2,7 +2,6 @@ package main
 
 import (
        "errors"
-       "golang.org/x/crypto/blake2s"
        "net"
        "sync"
        "time"
@@ -19,12 +18,10 @@ type Peer struct {
        keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
-       macKey                      [blake2s.Size]byte // Hash(Label-Mac1 || publicKey)
-       cookie                      []byte             // cookie
-       cookieExpire                time.Time
        queueInbound                chan []byte
        queueOutbound               chan *OutboundWorkQueueElement
        queueOutboundRouting        chan []byte
+       mac                         MacStatePeer
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
@@ -35,6 +32,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        peer.mutex.Lock()
        peer.device = device
        peer.keyPairs.Init()
+       peer.mac.Init(pk)
        peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
 
        // map public key
@@ -53,11 +51,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
        handshake.mutex.Lock()
        handshake.remoteStatic = pk
        handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
-
-       // compute mac key
-
-       peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...))
-
        handshake.mutex.Unlock()
        peer.mutex.Unlock()
 
index da5905d5491aa6960cf7070b210e5939ff0d03e8..f58d311ac28650779b0b96fd33fe34ed4c7f085e 100644 (file)
@@ -24,6 +24,10 @@ type OutboundWorkQueueElement struct {
        keyPair *KeyPair
 }
 
+func (peer *Peer) HandshakeWorker(handshakeQueue []byte) {
+
+}
+
 func (device *Device) SendPacket(packet []byte) {
 
        // lookup peer
@@ -39,7 +43,7 @@ func (device *Device) SendPacket(packet []byte) {
                peer = device.routingTable.LookupIPv6(dst)
 
        default:
-               device.logger.Println("unknown IP version")
+               device.log.Debug.Println("receieved packet with unknown IP version")
                return
        }
 
@@ -146,15 +150,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
 func (peer *Peer) RoutineSequential() {
        for work := range peer.queueOutbound {
                work.wg.Wait()
-
-               // check if dropped ("ghost packet")
-
                if work.packet == nil {
                        continue
                }
-
-               //
-
+               if peer.endpoint == nil {
+                       continue
+               }
+               peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
        }
 }