]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Added source verification
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 8 Jul 2017 07:23:10 +0000 (09:23 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 8 Jul 2017 07:23:10 +0000 (09:23 +0200)
src/config.go
src/device.go
src/peer.go
src/receive.go
src/send.go

index 8281581a194035bacf2be525c25a3c5f2349577c..4edaa2e1c749f1e81f257a523749b00dfa0d3829 100644 (file)
@@ -61,8 +61,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
                        if peer.endpoint != nil {
                                send("endpoint=" + peer.endpoint.String())
                        }
-                       send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
-                       send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
+                       send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
+                       send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
                        send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
                        for _, ip := range device.routingTable.AllowedIPs(peer) {
                                send("allowed_ip=" + ip.String())
@@ -73,7 +73,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
        // send lines
 
        for _, line := range lines {
-               device.log.Debug.Println("Response:", line)
                _, err := socket.WriteString(line + "\n")
                if err != nil {
                        return err
index 882d5870e12ed524a9088e4e8097944eac6d9f76..0564068daba79151feb13b8d09baa110ad27dca2 100644 (file)
@@ -31,10 +31,16 @@ type Device struct {
        signal struct {
                stop chan struct{}
        }
-       peers map[NoisePublicKey]*Peer
-       mac   MACStateDevice
+       congestionState int32 // used as an atomic bool
+       peers           map[NoisePublicKey]*Peer
+       mac             MACStateDevice
 }
 
+const (
+       CongestionStateUnderLoad = iota
+       CongestionStateOkay
+)
+
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
        device.mutex.Lock()
        defer device.mutex.Unlock()
@@ -93,6 +99,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
                go device.RoutineDecryption()
                go device.RoutineHandshake()
        }
+       go device.RoutineBusyMonitor()
        go device.RoutineReadFromTUN(tun)
        go device.RoutineReceiveIncomming()
        go device.RoutineWriteToTUN(tun)
index e3c80602de2076ed8793c4a6f03c7150b4410d10..fadc43f5ea99d0a3f4d2882c1df934194d27de82 100644 (file)
@@ -17,8 +17,8 @@ type Peer struct {
        keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
-       tx_bytes                    uint64
-       rx_bytes                    uint64
+       txBytes                     uint64
+       rxBytes                     uint64
        time                        struct {
                lastSend      time.Time // last send message
                lastHandshake time.Time // last completed handshake
index 7b16dc5716c41eb8df305d1c32ae74562f6d3547..c788dcf99b596e2251f671d4d6d191d5bc6177c8 100644 (file)
@@ -72,12 +72,48 @@ func addToHandshakeQueue(
        }
 }
 
-func (device *Device) RoutineReceiveIncomming() {
+/* Routine determining the busy state of the interface
+ *
+ * TODO: prehaps nicer to do this in response to events
+ * TODO: more well reasoned definition of "busy"
+ */
+func (device *Device) RoutineBusyMonitor() {
+       samples := 0
+       interval := time.Second
+       for timer := time.NewTimer(interval); ; {
+
+               select {
+               case <-device.signal.stop:
+                       return
+               case <-timer.C:
+               }
+
+               // compute busy heuristic
+
+               if len(device.queue.handshake) > QueueHandshakeBusySize {
+                       samples += 1
+               } else if samples > 0 {
+                       samples -= 1
+               }
+               samples %= 30
+               busy := samples > 5
+
+               // update busy state
+
+               if busy {
+                       atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad)
+               } else {
+                       atomic.StoreInt32(&device.congestionState, CongestionStateOkay)
+               }
+
+               timer.Reset(interval)
+       }
+}
 
-       debugLog := device.log.Debug
-       debugLog.Println("Routine, receive incomming, started")
+func (device *Device) RoutineReceiveIncomming() {
 
-       errorLog := device.log.Error
+       logDebug := device.log.Debug
+       logDebug.Println("Routine, receive incomming, started")
 
        var buffer []byte
 
@@ -122,33 +158,6 @@ func (device *Device) RoutineReceiveIncomming() {
 
                        case MessageInitiationType, MessageResponseType:
 
-                               // verify mac1
-
-                               if !device.mac.CheckMAC1(packet) {
-                                       debugLog.Println("Received packet with invalid mac1")
-                                       return
-                               }
-
-                               // check if busy, TODO: refine definition of "busy"
-
-                               busy := len(device.queue.handshake) > QueueHandshakeBusySize
-                               if busy && !device.mac.CheckMAC2(packet, raddr) {
-                                       sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
-                                       reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
-                                       if err != nil {
-                                               errorLog.Println("Failed to create cookie reply:", err)
-                                               return
-                                       }
-                                       writer := bytes.NewBuffer(packet[:0])
-                                       binary.Write(writer, binary.LittleEndian, reply)
-                                       packet = writer.Bytes()
-                                       _, err = device.net.conn.WriteToUDP(packet, raddr)
-                                       if err != nil {
-                                               debugLog.Println("Failed to send cookie reply:", err)
-                                       }
-                                       return
-                               }
-
                                // add to handshake queue
 
                                addToHandshakeQueue(
@@ -173,7 +182,7 @@ func (device *Device) RoutineReceiveIncomming() {
                                reader := bytes.NewReader(packet)
                                err := binary.Read(reader, binary.LittleEndian, &reply)
                                if err != nil {
-                                       debugLog.Println("Failed to decode cookie reply")
+                                       logDebug.Println("Failed to decode cookie reply")
                                        return
                                }
                                device.ConsumeMessageCookieReply(&reply)
@@ -218,7 +227,7 @@ func (device *Device) RoutineReceiveIncomming() {
 
                        default:
                                // unknown message type
-                               debugLog.Println("Got unknown message from:", raddr)
+                               logDebug.Println("Got unknown message from:", raddr)
                        }
                }()
        }
@@ -285,6 +294,38 @@ func (device *Device) RoutineHandshake() {
 
                func() {
 
+                       // verify mac1
+
+                       if !device.mac.CheckMAC1(elem.packet) {
+                               logDebug.Println("Received packet with invalid mac1")
+                               return
+                       }
+
+                       // verify mac2
+
+                       busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad
+
+                       if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
+                               sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
+                               reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
+                               if err != nil {
+                                       logError.Println("Failed to create cookie reply:", err)
+                                       return
+                               }
+                               writer := bytes.NewBuffer(elem.packet[:0])
+                               binary.Write(writer, binary.LittleEndian, reply)
+                               elem.packet = writer.Bytes()
+                               _, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
+                               if err != nil {
+                                       logDebug.Println("Failed to send cookie reply:", err)
+                               }
+                               return
+                       }
+
+                       // ratelimit
+
+                       // handle messages
+
                        switch elem.msgType {
                        case MessageInitiationType:
 
@@ -321,12 +362,12 @@ func (device *Device) RoutineHandshake() {
                                        logError.Println("Failed to create response message:", err)
                                        return
                                }
+
                                outElem := device.NewOutboundElement()
                                writer := bytes.NewBuffer(outElem.data[:0])
                                binary.Write(writer, binary.LittleEndian, response)
                                elem.packet = writer.Bytes()
                                peer.mac.AddMacs(elem.packet)
-                               device.log.Debug.Println(elem.packet)
                                addToOutboundQueue(peer.queue.outbound, outElem)
 
                        case MessageResponseType:
@@ -388,7 +429,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                }
                elem.mutex.Lock()
 
-               // process IP packet
+               // process packet
 
                func() {
                        if elem.IsDropped() {
@@ -407,30 +448,54 @@ func (peer *Peer) RoutineSequentialReceiver() {
                                return
                        }
 
-                       // strip padding
+                       // verify source and strip padding
 
                        switch elem.packet[0] >> 4 {
                        case IPv4version:
+
+                               // strip padding
+
                                if len(elem.packet) < IPv4headerSize {
                                        return
                                }
+
                                field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
                                length := binary.BigEndian.Uint16(field)
                                elem.packet = elem.packet[:length]
 
+                               // verify IPv4 source
+
+                               dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+                               if device.routingTable.LookupIPv4(dst) != peer {
+                                       return
+                               }
+
                        case IPv6version:
+
+                               // strip padding
+
                                if len(elem.packet) < IPv6headerSize {
                                        return
                                }
+
                                field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
                                length := binary.BigEndian.Uint16(field)
                                length += IPv6headerSize
                                elem.packet = elem.packet[:length]
 
+                               // verify IPv6 source
+
+                               dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+                               if device.routingTable.LookupIPv6(dst) != peer {
+                                       return
+                               }
+
                        default:
                                device.log.Debug.Println("Receieved packet with unknown IP version")
                                return
                        }
+
+                       atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet)))
                        addToInboundQueue(device.queue.inbound, elem)
                }()
        }
index d1de44abd4c90d49499e431a51c0486085655d5c..a02f5cb3b6b9cb38a2add51e869d14dc5252e023 100644 (file)
@@ -329,7 +329,7 @@ func (peer *Peer) RoutineSequentialSender() {
                                if err != nil {
                                        return
                                }
-                               atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
+                               atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
 
                                // shift keep-alive timer