]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Fixed cookie reply processing bug
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 7 Jul 2017 11:47:09 +0000 (13:47 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 7 Jul 2017 11:47:09 +0000 (13:47 +0200)
src/constants.go
src/device.go
src/handshake.go
src/ip.go
src/macs_device.go
src/macs_peer.go
src/macs_test.go
src/receive.go
src/send.go
src/tun_linux.go

index 053ba4f9be6d7923909584dddee27803c4efb3c8..6fbb7a0cfec35cbb7a256df683c919eaab81fac4 100644 (file)
@@ -12,7 +12,7 @@ const (
        RejectAfterTime         = time.Second * 180
        RejectAfterMessages     = (1 << 64) - (1 << 4) - 1
        KeepaliveTimeout        = time.Second * 10
-       CookieRefreshTime       = time.Second * 2
+       CookieRefreshTime       = time.Minute * 2
        MaxHandshakeAttemptTime = time.Second * 90
 )
 
index c57762394ec69371818efd60881e05f1f01016ec..882d5870e12ed524a9088e4e8097944eac6d9f76 100644 (file)
@@ -25,8 +25,8 @@ type Device struct {
        queue        struct {
                encryption chan *QueueOutboundElement
                decryption chan *QueueInboundElement
+               inbound    chan *QueueInboundElement
                handshake  chan QueueHandshakeElement
-               inbound    chan []byte // inbound queue for TUN
        }
        signal struct {
                stop chan struct{}
@@ -77,10 +77,10 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 
        // create queues
 
-       device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
        device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
+       device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
        device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
-       device.queue.inbound = make(chan []byte, QueueInboundSize)
+       device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
 
        // prepare signals
 
index 88bb8cb8803221c937493fba737f5ae0d9724a80..de607df174d0beff86173cb77a6c54e841f4f0dd 100644 (file)
@@ -112,7 +112,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                                binary.Write(writer, binary.LittleEndian, msg)
                                elem.packet = writer.Bytes()
                                peer.mac.AddMacs(elem.packet)
-                               println(elem)
                                addToOutboundQueue(peer.queue.outbound, elem)
 
                                if attempts == 0 {
index a9685adba450dc5d2b3a39ae4a79bc2399cc9359..36beb9c8131901c520df7da27ea01c59a78f39cb 100644 (file)
--- a/src/ip.go
+++ b/src/ip.go
@@ -5,14 +5,17 @@ import (
 )
 
 const (
-       IPv4version    = 4
-       IPv4offsetSrc  = 12
-       IPv4offsetDst  = IPv4offsetSrc + net.IPv4len
-       IPv4headerSize = 20
+       IPv4version           = 4
+       IPv4offsetTotalLength = 2
+       IPv4offsetSrc         = 12
+       IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
+       IPv4headerSize        = 20
 )
 
 const (
-       IPv6version   = 6
-       IPv6offsetSrc = 8
-       IPv6offsetDst = IPv6offsetSrc + net.IPv6len
+       IPv6version             = 6
+       IPv6offsetPayloadLength = 4
+       IPv6offsetSrc           = 8
+       IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
+       IPv6headerSize          = 40
 )
index deac803d6ae57a0d1fef3e437287e887288e769d..68181e60076ced137100957e3f7a34611f2cdca1 100644 (file)
@@ -15,21 +15,31 @@ type MACStateDevice struct {
        mutex     sync.RWMutex
        refreshed time.Time
        secret    [blake2s.Size]byte
-       keyMac1   [blake2s.Size]byte
+       keyMAC1   [blake2s.Size]byte
+       keyMAC2   [blake2s.Size]byte
        xaead     cipher.AEAD
 }
 
 func (state *MACStateDevice) Init(pk NoisePublicKey) {
        state.mutex.Lock()
        defer state.mutex.Unlock()
+
        func() {
                hsh, _ := blake2s.New256(nil)
                hsh.Write([]byte(WGLabelMAC1))
                hsh.Write(pk[:])
-               hsh.Sum(state.keyMac1[:0])
+               hsh.Sum(state.keyMAC1[:0])
        }()
-       state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMac1[:])
-       state.refreshed = time.Time{} // never
+
+       func() {
+               hsh, _ := blake2s.New256(nil)
+               hsh.Write([]byte(WGLabelCookie))
+               hsh.Write(pk[:])
+               hsh.Sum(state.keyMAC2[:0])
+       }()
+
+       state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMAC2[:])
+       state.refreshed = time.Time{}
 }
 
 func (state *MACStateDevice) CheckMAC1(msg []byte) bool {
@@ -39,7 +49,7 @@ func (state *MACStateDevice) CheckMAC1(msg []byte) bool {
 
        var mac1 [blake2s.Size128]byte
        func() {
-               mac, _ := blake2s.New128(state.keyMac1[:])
+               mac, _ := blake2s.New128(state.keyMAC1[:])
                mac.Write(msg[:startMac1])
                mac.Sum(mac1[:0])
        }()
@@ -117,7 +127,7 @@ func (device *Device) CreateMessageCookieReply(msg []byte, receiver uint32, addr
        startMac1 := size - (blake2s.Size128 * 2)
        startMac2 := size - blake2s.Size128
 
-       M := msg[startMac1:startMac2]
+       mac1 := msg[startMac1:startMac2]
 
        reply := new(MessageCookieReply)
        reply.Type = MessageCookieReplyType
@@ -127,7 +137,7 @@ func (device *Device) CreateMessageCookieReply(msg []byte, receiver uint32, addr
                state.mutex.RUnlock()
                return nil, err
        }
-       state.xaead.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], M)
+       state.xaead.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], mac1)
        state.mutex.RUnlock()
        return reply, nil
 }
@@ -149,9 +159,11 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
 
        var cookie [blake2s.Size128]byte
        state := &lookup.peer.mac
+
        state.mutex.Lock()
        defer state.mutex.Unlock()
-       _, err := state.xaead.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], state.lastMac1[:])
+
+       _, err := state.xaead.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], state.lastMAC1[:])
        if err != nil {
                return false
        }
index 5e9ebae47d930d1aa54171b3df71e000195b4351..16a7a87eb0c4feadedea0c1e6dff9853efd78731 100644 (file)
@@ -13,21 +13,31 @@ type MACStatePeer struct {
        mutex     sync.RWMutex
        cookieSet time.Time
        cookie    [blake2s.Size128]byte
-       lastMac1  [blake2s.Size128]byte
-       keyMac1   [blake2s.Size]byte
+       lastMAC1  [blake2s.Size128]byte
+       keyMAC1   [blake2s.Size]byte
+       keyMAC2   [blake2s.Size]byte
        xaead     cipher.AEAD
 }
 
 func (state *MACStatePeer) Init(pk NoisePublicKey) {
        state.mutex.Lock()
        defer state.mutex.Unlock()
+
        func() {
                hsh, _ := blake2s.New256(nil)
                hsh.Write([]byte(WGLabelMAC1))
                hsh.Write(pk[:])
-               hsh.Sum(state.keyMac1[:0])
+               hsh.Sum(state.keyMAC1[:0])
        }()
-       state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMac1[:])
+
+       func() {
+               hsh, _ := blake2s.New256(nil)
+               hsh.Write([]byte(WGLabelCookie))
+               hsh.Write(pk[:])
+               hsh.Sum(state.keyMAC2[:0])
+       }()
+
+       state.xaead, _ = chacha20poly1305.NewXCipher(state.keyMAC2[:])
        state.cookieSet = time.Time{} // never
 }
 
@@ -50,11 +60,11 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
        // set mac1
 
        func() {
-               mac, _ := blake2s.New128(state.keyMac1[:])
+               mac, _ := blake2s.New128(state.keyMAC1[:])
                mac.Write(msg[:startMac1])
-               mac.Sum(state.lastMac1[:0])
+               mac.Sum(state.lastMAC1[:0])
        }()
-       copy(mac1, state.lastMac1[:])
+       copy(mac1, state.lastMAC1[:])
 
        // set mac2
 
index a2a65035586b04885e1261dd9463bfdfca6a9b95..b7d5115c1ae05eed9d5a52d4b4082d18b8e63717 100644 (file)
@@ -1,7 +1,6 @@
 package main
 
 import (
-       "bytes"
        "net"
        "testing"
        "testing/quick"
@@ -17,8 +16,8 @@ func TestMAC1(t *testing.T) {
        peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
        peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
 
-       assertEqual(t, peer1.mac.keyMac1[:], dev1.mac.keyMac1[:])
-       assertEqual(t, peer2.mac.keyMac1[:], dev2.mac.keyMac1[:])
+       assertEqual(t, peer1.mac.keyMAC1[:], dev1.mac.keyMAC1[:])
+       assertEqual(t, peer2.mac.keyMAC1[:], dev2.mac.keyMAC1[:])
 
        msg1 := make([]byte, 256)
        copy(msg1, []byte("some content"))
@@ -52,17 +51,15 @@ func TestMACs(t *testing.T) {
                if addr.Port < 0 {
                        return true
                }
+
                addr.Port &= 0xffff
 
                if len(msg) < 32 {
                        return true
                }
-               if bytes.Compare(peer1.mac.keyMac1[:], device1.mac.keyMac1[:]) != 0 {
-                       return false
-               }
-               if bytes.Compare(peer2.mac.keyMac1[:], device2.mac.keyMac1[:]) != 0 {
-                       return false
-               }
+
+               assertEqual(t, peer1.mac.keyMAC1[:], device1.mac.keyMAC1[:])
+               assertEqual(t, peer2.mac.keyMAC1[:], device2.mac.keyMAC1[:])
 
                device2.indices.Insert(receiver, IndexTableEntry{
                        peer:      peer1,
@@ -83,17 +80,17 @@ func TestMACs(t *testing.T) {
                        return false
                }
 
-               if device2.ConsumeMessageCookieReply(cr) == false {
+               if !device2.ConsumeMessageCookieReply(cr) {
                        return false
                }
 
                // test MAC1 + MAC2
 
                peer1.mac.AddMacs(msg)
-               if device1.mac.CheckMAC1(msg) == false {
+               if !device1.mac.CheckMAC1(msg) {
                        return false
                }
-               if device1.mac.CheckMAC2(msg, &addr) == false {
+               if !device1.mac.CheckMAC2(msg, &addr) {
                        return false
                }
 
@@ -107,6 +104,8 @@ func TestMACs(t *testing.T) {
                        return false
                }
 
+               t.Log("Passed")
+
                return true
        }
 
index 50789a10edf0516c998fb6a486e3be81294c37e6..7b16dc5716c41eb8df305d1c32ae74562f6d3547 100644 (file)
@@ -55,6 +55,23 @@ func addToInboundQueue(
        }
 }
 
+func addToHandshakeQueue(
+       queue chan QueueHandshakeElement,
+       element QueueHandshakeElement,
+) {
+       for {
+               select {
+               case queue <- element:
+                       return
+               default:
+                       select {
+                       case <-queue:
+                       default:
+                       }
+               }
+       }
+}
+
 func (device *Device) RoutineReceiveIncomming() {
 
        debugLog := device.log.Debug
@@ -62,7 +79,7 @@ func (device *Device) RoutineReceiveIncomming() {
 
        errorLog := device.log.Error
 
-       var buffer []byte // unsliced buffer
+       var buffer []byte
 
        for {
 
@@ -116,7 +133,7 @@ func (device *Device) RoutineReceiveIncomming() {
 
                                busy := len(device.queue.handshake) > QueueHandshakeBusySize
                                if busy && !device.mac.CheckMAC2(packet, raddr) {
-                                       sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" follows "type"
+                                       sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
                                        reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
                                        if err != nil {
                                                errorLog.Println("Failed to create cookie reply:", err)
@@ -134,12 +151,15 @@ func (device *Device) RoutineReceiveIncomming() {
 
                                // add to handshake queue
 
+                               addToHandshakeQueue(
+                                       device.queue.handshake,
+                                       QueueHandshakeElement{
+                                               msgType: msgType,
+                                               packet:  packet,
+                                               source:  raddr,
+                                       },
+                               )
                                buffer = nil
-                               device.queue.handshake <- QueueHandshakeElement{
-                                       msgType: msgType,
-                                       packet:  packet,
-                                       source:  raddr,
-                               }
 
                        case MessageCookieReplyType:
 
@@ -293,7 +313,21 @@ func (device *Device) RoutineHandshake() {
                                        )
                                        return
                                }
-                               logDebug.Println("Recieved valid initiation message for peer", peer.id)
+
+                               // create response
+
+                               response, err := device.CreateMessageResponse(peer)
+                               if err != nil {
+                                       logError.Println("Failed to create response message:", err)
+                                       return
+                               }
+                               outElem := device.NewOutboundElement()
+                               writer := bytes.NewBuffer(outElem.data[:0])
+                               binary.Write(writer, binary.LittleEndian, response)
+                               elem.packet = writer.Bytes()
+                               peer.mac.AddMacs(elem.packet)
+                               device.log.Debug.Println(elem.packet)
+                               addToOutboundQueue(peer.queue.outbound, outElem)
 
                        case MessageResponseType:
 
@@ -352,29 +386,53 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        return
                case elem = <-peer.queue.inbound:
                }
-
                elem.mutex.Lock()
-               if elem.IsDropped() {
-                       continue
-               }
 
-               // check for replay
+               // process IP packet
+
+               func() {
+                       if elem.IsDropped() {
+                               return
+                       }
 
-               // update timers
+                       // check for replay
 
-               // check for keep-alive
+                       // update timers
 
-               if len(elem.packet) == 0 {
-                       continue
-               }
+                       // refresh key material
 
-               // strip padding
+                       // check for keep-alive
 
-               // insert into inbound TUN queue
+                       if len(elem.packet) == 0 {
+                               return
+                       }
 
-               device.queue.inbound <- elem.packet
+                       // strip padding
 
-               // update key material
+                       switch elem.packet[0] >> 4 {
+                       case IPv4version:
+                               if len(elem.packet) < IPv4headerSize {
+                                       return
+                               }
+                               field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+                               length := binary.BigEndian.Uint16(field)
+                               elem.packet = elem.packet[:length]
+
+                       case IPv6version:
+                               if len(elem.packet) < IPv6headerSize {
+                                       return
+                               }
+                               field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+                               length := binary.BigEndian.Uint16(field)
+                               length += IPv6headerSize
+                               elem.packet = elem.packet[:length]
+
+                       default:
+                               device.log.Debug.Println("Receieved packet with unknown IP version")
+                               return
+                       }
+                       addToInboundQueue(device.queue.inbound, elem)
+               }()
        }
 }
 
@@ -387,8 +445,8 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
                select {
                case <-device.signal.stop:
                        return
-               case packet := <-device.queue.inbound:
-                       _, err := tun.Write(packet)
+               case elem := <-device.queue.inbound:
+                       _, err := tun.Write(elem.packet)
                        if err != nil {
                                logError.Println("Failed to write packet to TUN device:", err)
                        }
index 405366905e488a9235b2bdd043ab5745a10b4b45..d1de44abd4c90d49499e431a51c0486085655d5c 100644 (file)
@@ -28,13 +28,13 @@ import (
  *
  * If the element is inserted into the "encryption queue",
  * the content is preceeded by enough "junk" to contain the header
- * (to allow the constuction of transport messages in-place)
+ * (to allow the construction of transport messages in-place)
  */
 type QueueOutboundElement struct {
        state   uint32
        mutex   sync.Mutex
        data    [MaxMessageSize]byte
-       packet  []byte   // slice of packet (sending)
+       packet  []byte   // slice of "data" (always!)
        nonce   uint64   // nonce for encryption
        keyPair *KeyPair // key-pair for encryption
        peer    *Peer    // related peer
@@ -51,8 +51,12 @@ func (peer *Peer) FlushNonceQueue() {
        }
 }
 
+/*
+ * Assumption: The mutex of the returned element is released
+ */
 func (device *Device) NewOutboundElement() *QueueOutboundElement {
-       elem := new(QueueOutboundElement) // TODO: profile, consider sync.Pool
+       // TODO: profile, consider sync.Pool
+       elem := new(QueueOutboundElement)
        return elem
 }
 
@@ -160,9 +164,8 @@ func (peer *Peer) RoutineNonce() {
        var elem *QueueOutboundElement
 
        device := peer.device
-       logger := device.log.Debug
-
-       logger.Println("Routine, nonce worker, started for peer", peer.id)
+       logDebug := device.log.Debug
+       logDebug.Println("Routine, nonce worker, started for peer", peer.id)
 
        func() {
 
@@ -193,18 +196,18 @@ func (peer *Peer) RoutineNonce() {
                                                break
                                        }
                                }
-                               logger.Println("Key pair:", keyPair)
+                               logDebug.Println("Key pair:", keyPair)
 
                                sendSignal(peer.signal.handshakeBegin)
-                               logger.Println("Waiting for key-pair, peer", peer.id)
+                               logDebug.Println("Waiting for key-pair, peer", peer.id)
 
                                select {
                                case <-peer.signal.newKeyPair:
-                                       logger.Println("Key-pair negotiated for peer", peer.id)
+                                       logDebug.Println("Key-pair negotiated for peer", peer.id)
                                        goto NextPacket
 
                                case <-peer.signal.flushNonceQueue:
-                                       logger.Println("Clearing queue for peer", peer.id)
+                                       logDebug.Println("Clearing queue for peer", peer.id)
                                        peer.FlushNonceQueue()
                                        elem = nil
                                        goto NextPacket
@@ -233,8 +236,6 @@ func (peer *Peer) RoutineNonce() {
                        }
                }
        }()
-
-       logger.Println("Routine, nonce worker, stopped for peer", peer.id)
 }
 
 /* Encrypts the elements in the queue
@@ -265,20 +266,16 @@ func (device *Device) RoutineEncryption() {
 
                // encrypt content
 
-               func() {
-                       binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
-                       work.packet = work.keyPair.send.Seal(
-                               work.packet[:0],
-                               nonce[:],
-                               work.packet,
-                               nil,
-                       )
-                       work.mutex.Unlock()
-               }()
-
-               // reslice to include header
-
-               work.packet = work.data[:MessageTransportHeaderSize+len(work.packet)]
+               binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
+               work.packet = work.keyPair.send.Seal(
+                       work.packet[:0],
+                       nonce[:],
+                       work.packet,
+                       nil,
+               )
+               length := MessageTransportHeaderSize + len(work.packet)
+               work.packet = work.data[:length]
+               work.mutex.Unlock()
 
                // refresh key if necessary
 
@@ -292,15 +289,15 @@ func (device *Device) RoutineEncryption() {
  * The routine terminates then the outbound queue is closed.
  */
 func (peer *Peer) RoutineSequentialSender() {
-       logger := peer.device.log.Debug
-       logger.Println("Routine, sequential sender, started for peer", peer.id)
+       logDebug := peer.device.log.Debug
+       logDebug.Println("Routine, sequential sender, started for peer", peer.id)
 
        device := peer.device
 
        for {
                select {
                case <-peer.signal.stop:
-                       logger.Println("Routine, sequential sender, stopped for peer", peer.id)
+                       logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
                        return
                case work := <-peer.queue.outbound:
                        if work.IsDropped() {
@@ -316,7 +313,7 @@ func (peer *Peer) RoutineSequentialSender() {
                                defer peer.mutex.RUnlock()
 
                                if peer.endpoint == nil {
-                                       logger.Println("No endpoint for peer:", peer.id)
+                                       logDebug.Println("No endpoint for peer:", peer.id)
                                        return
                                }
 
@@ -324,7 +321,7 @@ func (peer *Peer) RoutineSequentialSender() {
                                defer device.net.mutex.RUnlock()
 
                                if device.net.conn == nil {
-                                       logger.Println("No source for device")
+                                       logDebug.Println("No source for device")
                                        return
                                }
 
index a0bff81f9d7d0c6735808d3650bca8d677e16307..d16a966c7f8074ff6f51f92ec8d64c3dd6ad01dc 100644 (file)
@@ -74,6 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) {
        return &NativeTun{
                fd:   fd,
                name: newName,
-               mtu:  1500, // TODO: FIX
+               mtu:  0,
        }, nil
 }