]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Improved receive.go
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 11 Aug 2017 14:18:20 +0000 (16:18 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 11 Aug 2017 14:18:20 +0000 (16:18 +0200)
- Fixed configuration listen-port semantics
- Improved receive.go code for updating listen port
- Updated under load detection, how follows the kernel space implementation
- Fixed trie bug accidentally introduced in last commit
- Added interface name to log (format still subject to change)
- Can now configure the logging level using the LOG_LEVEL variable
- Begin porting netsh.sh tests
- A number of smaller changes

16 files changed:
src/config.go
src/conn.go [new file with mode: 0644]
src/constants.go
src/device.go
src/helper_test.go
src/index.go
src/logger.go
src/macs_test.go
src/main.go
src/misc.go
src/noise_test.go
src/receive.go
src/send.go
src/tests/netns.sh [new file with mode: 0755]
src/trie.go
src/uapi_linux.go

index d952a3a9b57c120ba826fc34d653f9b57abc59f8..474134b933369fcf96f6ef7f078a84ab5406c6dd 100644 (file)
@@ -28,6 +28,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        // create lines
 
        device.mutex.RLock()
+       device.net.mutex.RLock()
 
        lines := make([]string, 0, 100)
        send := func(line string) {
@@ -38,7 +39,9 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                send("private_key=" + device.privateKey.ToHex())
        }
 
-       send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
+       if device.net.addr != nil {
+               send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
+       }
 
        for _, peer := range device.peers {
                func() {
@@ -68,6 +71,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                }()
        }
 
+       device.net.mutex.RUnlock()
        device.mutex.RUnlock()
 
        // send lines
@@ -84,38 +88,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        return nil
 }
 
-func updateUDPConn(device *Device) error {
-       var err error
-       netc := &device.net
-       netc.mutex.Lock()
-
-       // close existing connection
-
-       if netc.conn != nil {
-               netc.conn.Close()
-               netc.conn = nil
-       }
-
-       // open new existing connection
-
-       conn, err := net.ListenUDP("udp", netc.addr)
-       if err == nil {
-               netc.conn = conn
-               signalSend(device.signal.newUDPConn)
-       }
-
-       netc.mutex.Unlock()
-       return err
-}
-
-func closeUDPConn(device *Device) {
-       device.net.mutex.Lock()
-       device.net.conn = nil
-       device.net.mutex.Unlock()
-       println("send signal")
-       signalSend(device.signal.newUDPConn)
-}
-
 func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        scanner := bufio.NewScanner(socket)
        logInfo := device.log.Info
@@ -166,13 +138,22 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logError.Println("Failed to set listen_port:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
+
+                               addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
+                               if err != nil {
+                                       logError.Println("Failed to set listen_port:", err)
+                                       return &IPCError{Code: ipcErrorInvalid}
+                               }
+
                                netc := &device.net
                                netc.mutex.Lock()
-                               if netc.addr.Port != int(port) {
-                                       netc.addr.Port = int(port)
-                               }
+                               netc.addr = addr
                                netc.mutex.Unlock()
-                               updateUDPConn(device)
+                               err = updateUDPConn(device)
+                               if err != nil {
+                                       logError.Println("Failed to set listen_port:", err)
+                                       return &IPCError{Code: ipcErrorIO}
+                               }
 
                                // TODO: Clear source address of all peers
 
@@ -298,7 +279,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                                logError.Println("Failed to get tun device status:", err)
                                                return &IPCError{Code: ipcErrorIO}
                                        }
-                                       if atomic.LoadInt32(&device.isUp) == AtomicTrue && !dummy {
+                                       if device.tun.isUp.Get() && !dummy {
                                                peer.SendKeepAlive()
                                        }
                                }
diff --git a/src/conn.go b/src/conn.go
new file mode 100644 (file)
index 0000000..f6472e9
--- /dev/null
@@ -0,0 +1,40 @@
+package main
+
+import (
+       "net"
+)
+
+func updateUDPConn(device *Device) error {
+       var err error
+       netc := &device.net
+       netc.mutex.Lock()
+
+       // close existing connection
+
+       if netc.conn != nil {
+               netc.conn.Close()
+       }
+
+       // open new connection
+
+       if device.tun.isUp.Get() {
+               conn, err := net.ListenUDP("udp", netc.addr)
+               if err == nil {
+                       netc.conn = conn
+                       signalSend(device.signal.newUDPConn)
+               }
+       }
+
+       netc.mutex.Unlock()
+       return err
+}
+
+func closeUDPConn(device *Device) {
+       netc := &device.net
+       netc.mutex.Lock()
+       if netc.conn != nil {
+               netc.conn.Close()
+       }
+       netc.mutex.Unlock()
+       signalSend(device.signal.newUDPConn)
+}
index 37603e8eb93b4408e0c45f19a470d168260c9601..d3666108e3048ae56fd9c801fc6bf6591cceeebf 100644 (file)
@@ -26,11 +26,15 @@ const (
 /* Implementation specific constants */
 
 const (
-       QueueOutboundSize      = 1024
-       QueueInboundSize       = 1024
-       QueueHandshakeSize     = 1024
-       QueueHandshakeBusySize = QueueHandshakeSize / 8
-       MinMessageSize         = MessageTransportSize // size of keep-alive
-       MaxMessageSize         = ((1 << 16) - 1) + MessageTransportHeaderSize
-       MaxPeers               = 1 << 16
+       QueueOutboundSize  = 1024
+       QueueInboundSize   = 1024
+       QueueHandshakeSize = 1024
+       MinMessageSize     = MessageTransportSize // size of keep-alive
+       MaxMessageSize     = ((1 << 16) - 1) + MessageTransportHeaderSize
+       MaxPeers           = 1 << 16
+)
+
+const (
+       UnderLoadQueueSize = QueueHandshakeSize / 8
+       UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
 )
index 4aa90e345a8ff207a65847fe8577287e786f6449..781b52502d492d3919a5864da5f70c7c9c9f4b56 100644 (file)
@@ -5,20 +5,22 @@ import (
        "runtime"
        "sync"
        "sync/atomic"
+       "time"
 )
 
 type Device struct {
-       mtu       int32
-       tun       TUNDevice
        log       *Logger // collection of loggers for levels
        idCounter uint    // for assigning debug ids to peers
        fwMark    uint32
-       pool      struct {
-               // pools objects for reuse
+       tun       struct {
+               device TUNDevice
+               isUp   AtomicBool
+               mtu    int32
+       }
+       pool struct {
                messageBuffers sync.Pool
        }
        net struct {
-               // seperate for performance reasons
                mutex sync.RWMutex
                addr  *net.UDPAddr // UDP source address
                conn  *net.UDPConn // UDP "connection"
@@ -35,13 +37,12 @@ type Device struct {
        }
        signal struct {
                stop       chan struct{} // halts all go routines
-               newUDPConn chan struct{} // a net.conn was set
+               newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
        }
-       isUp        int32 // atomic bool: interface is up
-       underLoad   int32 // atomic bool: device is under load
-       ratelimiter Ratelimiter
-       peers       map[NoisePublicKey]*Peer
-       mac         MACStateDevice
+       underLoadUntil atomic.Value
+       ratelimiter    Ratelimiter
+       peers          map[NoisePublicKey]*Peer
+       mac            MACStateDevice
 }
 
 /* Warning:
@@ -58,6 +59,23 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
        peer.Close()
 }
 
+func (device *Device) IsUnderLoad() bool {
+
+       // check if currently under load
+
+       now := time.Now()
+       underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+       if underLoad {
+               device.underLoadUntil.Store(now.Add(time.Second))
+               return true
+       }
+
+       // check if recently under load
+
+       until := device.underLoadUntil.Load().(time.Time)
+       return until.After(now)
+}
+
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
        device.mutex.Lock()
        defer device.mutex.Unlock()
@@ -115,20 +133,13 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        device.mutex.Lock()
        defer device.mutex.Unlock()
 
-       device.tun = tun
-       device.log = NewLogger(logLevel)
+       device.log = NewLogger(logLevel, "("+tun.Name()+") ")
        device.peers = make(map[NoisePublicKey]*Peer)
+       device.tun.device = tun
        device.indices.Init()
        device.ratelimiter.Init()
        device.routingTable.Reset()
-
-       // listen
-
-       device.net.mutex.Lock()
-       device.net.conn, _ = net.ListenUDP("udp", device.net.addr)
-       addr := device.net.conn.LocalAddr()
-       device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String())
-       device.net.mutex.Unlock()
+       device.underLoadUntil.Store(time.Time{})
 
        // setup pools
 
@@ -157,42 +168,43 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
                go device.RoutineHandshake()
        }
 
-       go device.RoutineBusyMonitor()
-       go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
-       go device.RoutineReceiveIncomming()
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
+       go device.RoutineReadFromTUN()
+       go device.RoutineReceiveIncomming()
 
        return device
 }
 
 func (device *Device) RoutineTUNEventReader() {
-       events := device.tun.Events()
+       logInfo := device.log.Info
        logError := device.log.Error
 
+       events := device.tun.device.Events()
+
        for event := range events {
                if event&TUNEventMTUUpdate != 0 {
-                       mtu, err := device.tun.MTU()
+                       mtu, err := device.tun.device.MTU()
                        if err != nil {
                                logError.Println("Failed to load updated MTU of device:", err)
                        } else {
                                if mtu+MessageTransportSize > MaxMessageSize {
                                        mtu = MaxMessageSize - MessageTransportSize
                                }
-                               atomic.StoreInt32(&device.mtu, int32(mtu))
+                               atomic.StoreInt32(&device.tun.mtu, int32(mtu))
                        }
                }
 
                if event&TUNEventUp != 0 {
-                       println("handle 1")
-                       atomic.StoreInt32(&device.isUp, AtomicTrue)
+                       device.tun.isUp.Set(true)
                        updateUDPConn(device)
-                       println("handle 2", device.net.conn)
+                       logInfo.Println("Interface set up")
                }
 
                if event&TUNEventDown != 0 {
-                       atomic.StoreInt32(&device.isUp, AtomicFalse)
+                       device.tun.isUp.Set(false)
                        closeUDPConn(device)
+                       logInfo.Println("Interface set down")
                }
        }
 }
@@ -224,6 +236,7 @@ func (device *Device) RemoveAllPeers() {
 func (device *Device) Close() {
        device.RemoveAllPeers()
        close(device.signal.stop)
+       closeUDPConn(device)
 }
 
 func (device *Device) WaitChannel() chan struct{} {
index 3838a7cd0928e997a5266a5f422057b24cca8939..fc171e875b4a20c9dfb6ca54ea497a1ed92f4da3 100644 (file)
@@ -12,6 +12,7 @@ type DummyTUN struct {
        name    string
        mtu     int
        packets chan []byte
+       events  chan TUNEvent
 }
 
 func (tun *DummyTUN) Name() string {
@@ -27,6 +28,14 @@ func (tun *DummyTUN) Write(d []byte) (int, error) {
        return len(d), nil
 }
 
+func (tun *DummyTUN) Close() error {
+       return nil
+}
+
+func (tun *DummyTUN) Events() chan TUNEvent {
+       return tun.events
+}
+
 func (tun *DummyTUN) Read(d []byte) (int, error) {
        t := <-tun.packets
        copy(d, t)
index e518b0f8c2adeb39759f02806c0d1227f24d16ce..1ba040ed323447d232bbd69e77b0fabc6bc3f9f6 100644 (file)
@@ -2,8 +2,8 @@ package main
 
 import (
        "crypto/rand"
+       "encoding/binary"
        "sync"
-       "unsafe"
 )
 
 /* Index=0 is reserved for unset indecies
@@ -24,7 +24,8 @@ type IndexTable struct {
 func randUint32() (uint32, error) {
        var buff [4]byte
        _, err := rand.Read(buff[:])
-       return *((*uint32)(unsafe.Pointer(&buff))), err
+       value := binary.LittleEndian.Uint32(buff[:])
+       return value, err
 }
 
 func (table *IndexTable) Init() {
index 9fe73b43196e023b347a7045cc88046ce2b31924..0872ef93d9d9c98fb35ccc8c65e6211e7c8f8555 100644 (file)
@@ -19,7 +19,7 @@ type Logger struct {
        Error *log.Logger
 }
 
-func NewLogger(level int) *Logger {
+func NewLogger(level int, prepend string) *Logger {
        output := os.Stdout
        logger := new(Logger)
 
@@ -34,16 +34,16 @@ func NewLogger(level int) *Logger {
        }()
 
        logger.Debug = log.New(logDebug,
-               "DEBUG: ",
+               "DEBUG: "+prepend,
                log.Ldate|log.Ltime|log.Lshortfile,
        )
 
        logger.Info = log.New(logInfo,
-               "INFO: ",
+               "INFO: "+prepend,
                log.Ldate|log.Ltime,
        )
        logger.Error = log.New(logErr,
-               "ERROR: ",
+               "ERROR: "+prepend,
                log.Ldate|log.Ltime,
        )
        return logger
index b7d5115c1ae05eed9d5a52d4b4082d18b8e63717..3575ccb9388dcd1724b942962044a8489172dbb0 100644 (file)
@@ -13,8 +13,8 @@ func TestMAC1(t *testing.T) {
        defer dev1.Close()
        defer dev2.Close()
 
-       peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
-       peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
+       peer1, _ := dev2.NewPeer(dev1.privateKey.publicKey())
+       peer2, _ := dev1.NewPeer(dev2.privateKey.publicKey())
 
        assertEqual(t, peer1.mac.keyMAC1[:], dev1.mac.keyMAC1[:])
        assertEqual(t, peer2.mac.keyMAC1[:], dev2.mac.keyMAC1[:])
@@ -45,8 +45,8 @@ func TestMACs(t *testing.T) {
                defer device1.Close()
                defer device2.Close()
 
-               peer1 := device2.NewPeer(device1.privateKey.publicKey())
-               peer2 := device1.NewPeer(device2.privateKey.publicKey())
+               peer1, _ := device2.NewPeer(device1.privateKey.publicKey())
+               peer2, _ := device1.NewPeer(device2.privateKey.publicKey())
 
                if addr.Port < 0 {
                        return true
index dde21fb02ad3b5e8466db046f9ed7c495e003ead..196a4c607add27a7f06b0a8ad3e2b67772c754b6 100644 (file)
@@ -65,9 +65,23 @@ func main() {
                return
        }
 
+       // get log level (default: info)
+
+       logLevel := func() int {
+               switch os.Getenv("LOG_LEVEL") {
+               case "debug":
+                       return LogLevelDebug
+               case "info":
+                       return LogLevelInfo
+               case "error":
+                       return LogLevelError
+               }
+               return LogLevelInfo
+       }()
+
        // create wireguard device
 
-       device := NewDevice(tun, LogLevelDebug)
+       device := NewDevice(tun, logLevel)
 
        logInfo := device.log.Info
        logError := device.log.Error
index fc75c0d5d0effe887b2b8312bd8227ee07b95c88..d93849e1b383e77f7e2a19f7efea1a8541313978 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "sync/atomic"
        "time"
 )
 
@@ -8,10 +9,26 @@ import (
  * (since booleans are not natively supported by sync/atomic)
  */
 const (
-       AtomicFalse = iota
+       AtomicFalse = int32(iota)
        AtomicTrue
 )
 
+type AtomicBool struct {
+       flag int32
+}
+
+func (a *AtomicBool) Get() bool {
+       return atomic.LoadInt32(&a.flag) == AtomicTrue
+}
+
+func (a *AtomicBool) Set(val bool) {
+       flag := AtomicFalse
+       if val {
+               flag = AtomicTrue
+       }
+       atomic.StoreInt32(&a.flag, flag)
+}
+
 func min(a uint, b uint) uint {
        if a > b {
                return b
index 86ddce9d27f495a889168df57068630713c75e82..0d7f0e933fadc28765f6b804e7647f6f0b54e4a3 100644 (file)
@@ -31,8 +31,8 @@ func TestNoiseHandshake(t *testing.T) {
        defer dev1.Close()
        defer dev2.Close()
 
-       peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
-       peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
+       peer1, _ := dev2.NewPeer(dev1.privateKey.publicKey())
+       peer2, _ := dev1.NewPeer(dev2.privateKey.publicKey())
 
        assertEqual(
                t,
index 5f469257e6946f3af07868d2bd2ab858aa469056..c47d93c13520e0910f239868ebe17272f5df286f 100644 (file)
@@ -72,43 +72,6 @@ func (device *Device) addToHandshakeQueue(
        }
 }
 
-/* Routine determining the busy state of the interface
- *
- * TODO: Under load for some time
- */
-func (device *Device) RoutineBusyMonitor() {
-       samples := 0
-       interval := time.Second
-       for timer := time.NewTimer(interval); ; {
-
-               select {
-               case <-device.signal.stop:
-                       return
-               case <-timer.C:
-               }
-
-               // compute busy heuristic
-
-               if len(device.queue.handshake) > QueueHandshakeBusySize {
-                       samples += 1
-               } else if samples > 0 {
-                       samples -= 1
-               }
-               samples %= 30
-               busy := samples > 5
-
-               // update busy state
-
-               if busy {
-                       atomic.StoreInt32(&device.underLoad, AtomicTrue)
-               } else {
-                       atomic.StoreInt32(&device.underLoad, AtomicFalse)
-               }
-
-               timer.Reset(interval)
-       }
-}
-
 func (device *Device) RoutineReceiveIncomming() {
 
        logDebug := device.log.Debug
@@ -118,117 +81,121 @@ func (device *Device) RoutineReceiveIncomming() {
 
                // wait for new conn
 
-               var conn *net.UDPConn
+               logDebug.Println("Waiting for udp socket")
 
                select {
+               case <-device.signal.stop:
+                       return
+
                case <-device.signal.newUDPConn:
+
+                       // fetch connection
+
                        device.net.mutex.RLock()
-                       conn = device.net.conn
+                       conn := device.net.conn
                        device.net.mutex.RUnlock()
+                       if conn == nil {
+                               continue
+                       }
 
-               case <-device.signal.stop:
-                       return
-               }
-
-               if conn == nil {
-                       continue
-               }
+                       logDebug.Println("Listening for inbound packets")
 
-               // receive datagrams until closed
+                       // receive datagrams until conn is closed
 
-               buffer := device.GetMessageBuffer()
+                       buffer := device.GetMessageBuffer()
 
-               for {
+                       for {
 
-                       // read next datagram
+                               // read next datagram
 
-                       size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken
+                               size, raddr, err := conn.ReadFromUDP(buffer[:]) // Blocks sometimes
 
-                       if err != nil {
-                               break
-                       }
+                               if err != nil {
+                                       break
+                               }
 
-                       if size < MinMessageSize {
-                               continue
-                       }
+                               if size < MinMessageSize {
+                                       continue
+                               }
 
-                       // check size of packet
+                               // check size of packet
 
-                       packet := buffer[:size]
-                       msgType := binary.LittleEndian.Uint32(packet[:4])
+                               packet := buffer[:size]
+                               msgType := binary.LittleEndian.Uint32(packet[:4])
 
-                       var okay bool
+                               var okay bool
 
-                       switch msgType {
+                               switch msgType {
 
-                       // check if transport
+                               // check if transport
 
-                       case MessageTransportType:
+                               case MessageTransportType:
 
-                               // check size
+                                       // check size
 
-                               if len(packet) < MessageTransportType {
-                                       continue
-                               }
+                                       if len(packet) < MessageTransportType {
+                                               continue
+                                       }
 
-                               // lookup key pair
+                                       // lookup key pair
 
-                               receiver := binary.LittleEndian.Uint32(
-                                       packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
-                               )
-                               value := device.indices.Lookup(receiver)
-                               keyPair := value.keyPair
-                               if keyPair == nil {
-                                       continue
-                               }
+                                       receiver := binary.LittleEndian.Uint32(
+                                               packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+                                       )
+                                       value := device.indices.Lookup(receiver)
+                                       keyPair := value.keyPair
+                                       if keyPair == nil {
+                                               continue
+                                       }
 
-                               // check key-pair expiry
+                                       // check key-pair expiry
 
-                               if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
-                                       continue
-                               }
+                                       if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+                                               continue
+                                       }
 
-                               // create work element
+                                       // create work element
 
-                               peer := value.peer
-                               elem := &QueueInboundElement{
-                                       packet:  packet,
-                                       buffer:  buffer,
-                                       keyPair: keyPair,
-                                       dropped: AtomicFalse,
-                               }
-                               elem.mutex.Lock()
+                                       peer := value.peer
+                                       elem := &QueueInboundElement{
+                                               packet:  packet,
+                                               buffer:  buffer,
+                                               keyPair: keyPair,
+                                               dropped: AtomicFalse,
+                                       }
+                                       elem.mutex.Lock()
 
-                               // add to decryption queues
+                                       // add to decryption queues
 
-                               device.addToInboundQueue(device.queue.decryption, elem)
-                               device.addToInboundQueue(peer.queue.inbound, elem)
-                               buffer = nil
-                               continue
+                                       device.addToInboundQueue(device.queue.decryption, elem)
+                                       device.addToInboundQueue(peer.queue.inbound, elem)
+                                       buffer = device.GetMessageBuffer()
+                                       continue
 
-                       // otherwise it is a handshake related packet
+                               // otherwise it is a handshake related packet
 
-                       case MessageInitiationType:
-                               okay = len(packet) == MessageInitiationSize
+                               case MessageInitiationType:
+                                       okay = len(packet) == MessageInitiationSize
 
-                       case MessageResponseType:
-                               okay = len(packet) == MessageResponseSize
+                               case MessageResponseType:
+                                       okay = len(packet) == MessageResponseSize
 
-                       case MessageCookieReplyType:
-                               okay = len(packet) == MessageCookieReplySize
-                       }
+                               case MessageCookieReplyType:
+                                       okay = len(packet) == MessageCookieReplySize
+                               }
 
-                       if okay {
-                               device.addToHandshakeQueue(
-                                       device.queue.handshake,
-                                       QueueHandshakeElement{
-                                               msgType: msgType,
-                                               buffer:  buffer,
-                                               packet:  packet,
-                                               source:  raddr,
-                                       },
-                               )
-                               buffer = device.GetMessageBuffer()
+                               if okay {
+                                       device.addToHandshakeQueue(
+                                               device.queue.handshake,
+                                               QueueHandshakeElement{
+                                                       msgType: msgType,
+                                                       buffer:  buffer,
+                                                       packet:  packet,
+                                                       source:  raddr,
+                                               },
+                                       )
+                                       buffer = device.GetMessageBuffer()
+                               }
                        }
                }
        }
@@ -326,10 +293,11 @@ func (device *Device) RoutineHandshake() {
                                return
                        }
 
-                       busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
-
-                       if busy {
+                       if device.IsUnderLoad() {
                                if !device.mac.CheckMAC2(elem.packet, elem.source) {
+
+                                       // construct cookie reply
+
                                        sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
                                        reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
                                        if err != nil {
@@ -347,6 +315,7 @@ func (device *Device) RoutineHandshake() {
                                        }
                                        continue
                                }
+
                                if !device.ratelimiter.Allow(elem.source.IP) {
                                        continue
                                }
@@ -577,7 +546,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                // write to tun
 
                atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
-               _, err := device.tun.Write(elem.packet)
+               _, err := device.tun.device.Write(elem.packet)
                device.PutMessageBuffer(elem.buffer)
                if err != nil {
                        logError.Println("Failed to write packet to TUN device:", err)
index cf1f018fe9835a6b93f9b40ae0210b5bd7f9610e..0de3c0a1181b855986b841724ab6fca0b3c028bc 100644 (file)
@@ -137,10 +137,6 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
  */
 func (device *Device) RoutineReadFromTUN() {
 
-       if device.tun == nil {
-               return
-       }
-
        var elem *QueueOutboundElement
 
        logDebug := device.log.Debug
@@ -155,9 +151,8 @@ func (device *Device) RoutineReadFromTUN() {
                        elem = device.NewOutboundElement()
                }
 
-               // TODO: THIS!
                elem.packet = elem.buffer[MessageTransportHeaderSize:]
-               size, err := device.tun.Read(elem.packet)
+               size, err := device.tun.device.Read(elem.packet)
                if err != nil {
                        logError.Println("Failed to read packet from TUN device:", err)
                        device.Close()
@@ -345,7 +340,7 @@ func (device *Device) RoutineEncryption() {
 
                // pad content to MTU size
 
-               mtu := int(atomic.LoadInt32(&device.mtu))
+               mtu := int(atomic.LoadInt32(&device.tun.mtu))
                pad := len(elem.packet) % PaddingMultiple
                if pad > 0 {
                        for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
diff --git a/src/tests/netns.sh b/src/tests/netns.sh
new file mode 100755 (executable)
index 0000000..9f003e2
--- /dev/null
@@ -0,0 +1,350 @@
+#!/bin/bash
+
+# Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+
+# This script tests the below topology:
+#
+# ┌─────────────────────┐   ┌──────────────────────────────────┐   ┌─────────────────────┐
+# │   $ns1 namespace    │   │          $ns0 namespace          │   │   $ns2 namespace    │
+# │                     │   │                                  │   │                     │
+# │┌────────┐           │   │            ┌────────┐            │   │           ┌────────┐│
+# ││  wg1   │───────────┼───┼────────────│   lo   │────────────┼───┼───────────│  wg2   ││
+# │├────────┴──────────┐│   │    ┌───────┴────────┴────────┐   │   │┌──────────┴────────┤│
+# ││192.168.241.1/24   ││   │    │(ns1)         (ns2)      │   │   ││192.168.241.2/24   ││
+# ││fd00::1/24         ││   │    │127.0.0.1:1   127.0.0.1:2│   │   ││fd00::2/24         ││
+# │└───────────────────┘│   │    │[::]:1        [::]:2     │   │   │└───────────────────┘│
+# └─────────────────────┘   │    └─────────────────────────┘   │   └─────────────────────┘
+#                           └──────────────────────────────────┘
+#
+# After the topology is prepared we run a series of TCP/UDP iperf3 tests between the
+# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
+# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
+# details on how this is accomplished.
+set -e
+
+exec 3>&1
+export WG_HIDE_KEYS=never
+netns0="wg-test-$$-0"
+netns1="wg-test-$$-1"
+netns2="wg-test-$$-2"
+program="../wireguard-go"
+export LOG_LEVEL="debug"
+
+pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
+pp() { pretty "" "$*"; "$@"; }
+maybe_exec() { if [[ $BASHPID -eq $$ ]]; then "$@"; else exec "$@"; fi; }
+n0() { pretty 0 "$*"; maybe_exec ip netns exec $netns0 "$@"; }
+n1() { pretty 1 "$*"; maybe_exec ip netns exec $netns1 "$@"; }
+n2() { pretty 2 "$*"; maybe_exec ip netns exec $netns2 "$@"; }
+ip0() { pretty 0 "ip $*"; ip -n $netns0 "$@"; }
+ip1() { pretty 1 "ip $*"; ip -n $netns1 "$@"; }
+ip2() { pretty 2 "ip $*"; ip -n $netns2 "$@"; }
+sleep() { read -t "$1" -N 0 || true; }
+waitiperf() { pretty "${1//*-}" "wait for iperf:5201"; while [[ $(ss -N "$1" -tlp 'sport = 5201') != *iperf3* ]]; do sleep 0.1; done; }
+waitncatudp() { pretty "${1//*-}" "wait for udp:1111"; while [[ $(ss -N "$1" -ulp 'sport = 1111') != *ncat* ]]; do sleep 0.1; done; }
+waitiface() { pretty "${1//*-}" "wait for $2 to come up"; ip netns exec "$1" bash -c "while [[ \$(< \"/sys/class/net/$2/operstate\") != up ]]; do read -t .1 -N 0 || true; done;"; }
+
+cleanup() {
+    n0 wg show
+    set +e
+    exec 2>/dev/null
+    printf "$orig_message_cost" > /proc/sys/net/core/message_cost
+    ip0 link del dev wg1
+    ip1 link del dev wg1
+    ip2 link del dev wg1
+    local to_kill="$(ip netns pids $netns0) $(ip netns pids $netns1) $(ip netns pids $netns2)"
+    [[ -n $to_kill ]] && kill $to_kill
+    pp ip netns del $netns1
+    pp ip netns del $netns2
+    pp ip netns del $netns0
+    exit
+}
+
+orig_message_cost="$(< /proc/sys/net/core/message_cost)"
+trap cleanup EXIT
+printf 0 > /proc/sys/net/core/message_cost
+
+ip netns del $netns0 2>/dev/null || true
+ip netns del $netns1 2>/dev/null || true
+ip netns del $netns2 2>/dev/null || true
+pp ip netns add $netns0
+pp ip netns add $netns1
+pp ip netns add $netns2
+ip0 link set up dev lo
+
+# ip0 link add dev wg1 type wireguard
+n0 $program -f wg1 &
+sleep 1
+ip0 link set wg1 netns $netns1
+
+# ip0 link add dev wg1 type wireguard
+n0 $program -f wg2 &
+sleep 1
+ip0 link set wg2 netns $netns2
+
+key1="$(pp wg genkey)"
+key2="$(pp wg genkey)"
+pub1="$(pp wg pubkey <<<"$key1")"
+pub2="$(pp wg pubkey <<<"$key2")"
+psk="$(pp wg genpsk)"
+[[ -n $key1 && -n $key2 && -n $psk ]]
+
+configure_peers() {
+
+    ip1 addr add 192.168.241.1/24 dev wg1
+    ip1 addr add fd00::1/24 dev wg1
+
+    ip2 addr add 192.168.241.2/24 dev wg2
+    ip2 addr add fd00::2/24 dev wg2
+
+    n0 wg set wg1 \
+        private-key <(echo "$key1") \
+        listen-port 10000 \
+        peer "$pub2" \
+            preshared-key <(echo "$psk") \
+            allowed-ips 192.168.241.2/32,fd00::2/128
+    n0 wg set wg2 \
+        private-key <(echo "$key2") \
+        listen-port 20000 \
+        peer "$pub1" \
+            preshared-key <(echo "$psk") \
+            allowed-ips 192.168.241.1/32,fd00::1/128
+
+    n0 wg showconf wg1
+    n0 wg showconf wg2
+
+    ip1 link set up dev wg1
+    ip2 link set up dev wg2
+}
+configure_peers
+
+tests() {
+    # Ping over IPv4
+    n2 ping -c 10 -f -W 1 192.168.241.1
+    n1 ping -c 10 -f -W 1 192.168.241.2
+
+    # Ping over IPv6
+    n2 ping6 -c 10 -f -W 1 fd00::1
+    n1 ping6 -c 10 -f -W 1 fd00::2
+
+    # TCP over IPv4
+    n2 iperf3 -s -1 -B 192.168.241.2 &
+    waitiperf $netns2
+    n1 iperf3 -Z -n 1G -c 192.168.241.2
+
+    # TCP over IPv6
+    n1 iperf3 -s -1 -B fd00::1 &
+    waitiperf $netns1
+    n2 iperf3 -Z -n 1G -c fd00::1
+
+    # UDP over IPv4
+    n1 iperf3 -s -1 -B 192.168.241.1 &
+    waitiperf $netns1
+    n2 iperf3 -Z -n 1G -b 0 -u -c 192.168.241.1
+
+    # UDP over IPv6
+    n2 iperf3 -s -1 -B fd00::2 &
+    waitiperf $netns2
+    n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2
+}
+
+[[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}"
+big_mtu=$(( 34816 - 1500 + $orig_mtu ))
+
+# Test using IPv4 as outer transport
+n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
+n0 wg set wg2 peer "$pub1" endpoint 127.0.0.1:10000
+n0 wg show
+# Before calling tests, we first make sure that the stats counters are working
+n2 ping -c 10 -f -W 1 192.168.241.1
+{ read _; read _; read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip2 -stats link show dev wg2)
+[[ $rx_bytes -ge 932 && $tx_bytes -ge 1516 && $rx_bytes -lt 2500 && $rx_bytes -lt 2500 ]]
+tests
+ip1 link set wg1 mtu $big_mtu
+ip2 link set wg2 mtu $big_mtu
+tests
+
+ip1 link set wg1 mtu $orig_mtu
+ip2 link set wg2 mtu $orig_mtu
+
+# Test using IPv6 as outer transport
+n0 wg set wg1 peer "$pub2" endpoint [::1]:20000
+n0 wg set wg2 peer "$pub1" endpoint [::1]:10000
+tests
+ip1 link set wg1 mtu $big_mtu
+ip2 link set wg2 mtu $big_mtu
+tests
+
+ip1 link set wg1 mtu $orig_mtu
+ip2 link set wg2 mtu $orig_mtu
+
+# Test using IPv4 that roaming works
+ip0 -4 addr del 127.0.0.1/8 dev lo
+ip0 -4 addr add 127.212.121.99/8 dev lo
+n0 wg set wg1 listen-port 9999
+n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
+n1 ping6 -W 1 -c 1 fd00::20000
+[[ $(n2 wg show wg2 endpoints) == "$pub1    127.212.121.99:9999" ]]
+
+# Test using IPv6 that roaming works
+n1 wg set wg1 listen-port 9998
+n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
+n1 ping -W 1 -c 1 192.168.241.2
+[[ $(n2 wg show wg2 endpoints) == "$pub1    [::1]:9998" ]]
+
+# Test that crypto-RP filter works
+n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
+exec 4< <(n1 ncat -l -u -p 1111)
+nmap_pid=$!
+waitncatudp $netns1
+n2 ncat -u 192.168.241.1 1111 <<<"X"
+read -r -N 1 -t 1 out <&4 && [[ $out == "X" ]]
+kill $nmap_pid
+more_specific_key="$(pp wg genkey | pp wg pubkey)"
+n0 wg set wg1 peer "$more_specific_key" allowed-ips 192.168.241.2/32
+n0 wg set wg2 listen-port 9997
+exec 4< <(n1 ncat -l -u -p 1111)
+nmap_pid=$!
+waitncatudp $netns1
+n2 ncat -u 192.168.241.1 1111 <<<"X"
+! read -r -N 1 -t 1 out <&4
+kill $nmap_pid
+n0 wg set wg1 peer "$more_specific_key" remove
+[[ $(n1 wg show wg1 endpoints) == "$pub2    [::1]:9997" ]]
+
+ip1 link del wg1
+ip2 link del wg2
+
+# Test using NAT. We now change the topology to this:
+# ┌────────────────────────────────────────┐    ┌────────────────────────────────────────────────┐     ┌────────────────────────────────────────┐
+# │             $ns1 namespace             │    │                 $ns0 namespace                 │     │             $ns2 namespace             │
+# │                                        │    │                                                │     │                                        │
+# │  ┌─────┐             ┌─────┐           │    │    ┌──────┐              ┌──────┐              │     │  ┌─────┐            ┌─────┐            │
+# │  │ wg1 │─────────────│vethc│───────────┼────┼────│vethrc│              │vethrs│──────────────┼─────┼──│veths│────────────│ wg2 │            │
+# │  ├─────┴──────────┐  ├─────┴──────────┐│    │    ├──────┴─────────┐    ├──────┴────────────┐ │     │  ├─────┴──────────┐ ├─────┴──────────┐ │
+# │  │192.168.241.1/24│  │192.168.1.100/24││    │    │192.168.1.100/24│    │10.0.0.1/24        │ │     │  │10.0.0.100/24   │ │192.168.241.2/24│ │
+# │  │fd00::1/24      │  │                ││    │    │                │    │SNAT:192.168.1.0/24│ │     │  │                │ │fd00::2/24      │ │
+# │  └────────────────┘  └────────────────┘│    │    └────────────────┘    └───────────────────┘ │     │  └────────────────┘ └────────────────┘ │
+# └────────────────────────────────────────┘    └────────────────────────────────────────────────┘     └────────────────────────────────────────┘
+
+# ip1 link add dev wg1 type wireguard
+# ip2 link add dev wg1 type wireguard
+
+n1 $program wg1
+n2 $program wg2
+
+configure_peers
+
+ip0 link add vethrc type veth peer name vethc
+ip0 link add vethrs type veth peer name veths
+ip0 link set vethc netns $netns1
+ip0 link set veths netns $netns2
+ip0 link set vethrc up
+ip0 link set vethrs up
+ip0 addr add 192.168.1.1/24 dev vethrc
+ip0 addr add 10.0.0.1/24 dev vethrs
+ip1 addr add 192.168.1.100/24 dev vethc
+ip1 link set vethc up
+ip1 route add default via 192.168.1.1
+ip2 addr add 10.0.0.100/24 dev veths
+ip2 link set veths up
+waitiface $netns0 vethrc
+waitiface $netns0 vethrs
+waitiface $netns1 vethc
+waitiface $netns2 veths
+
+n0 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward'
+n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout'
+n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream'
+n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to 10.0.0.1
+
+n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
+n1 ping -W 1 -c 1 192.168.241.2
+n2 ping -W 1 -c 1 192.168.241.1
+[[ $(n2 wg show wg2 endpoints) == "$pub1    10.0.0.1:10000" ]]
+# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
+pp sleep 3
+n2 ping -W 1 -c 1 192.168.241.1
+
+n0 iptables -t nat -F
+ip0 link del vethrc
+ip0 link del vethrs
+ip1 link del wg1
+ip2 link del wg2
+
+# Test that saddr routing is sticky but not too sticky, changing to this topology:
+# ┌────────────────────────────────────────┐    ┌────────────────────────────────────────┐
+# │             $ns1 namespace             │    │             $ns2 namespace             │
+# │                                        │    │                                        │
+# │  ┌─────┐             ┌─────┐           │    │  ┌─────┐            ┌─────┐            │
+# │  │ wg1 │─────────────│veth1│───────────┼────┼──│veth2│────────────│ wg2 │            │
+# │  ├─────┴──────────┐  ├─────┴──────────┐│    │  ├─────┴──────────┐ ├─────┴──────────┐ │
+# │  │192.168.241.1/24│  │10.0.0.1/24     ││    │  │10.0.0.2/24     │ │192.168.241.2/24│ │
+# │  │fd00::1/24      │  │fd00:aa::1/96   ││    │  │fd00:aa::2/96   │ │fd00::2/24      │ │
+# │  └────────────────┘  └────────────────┘│    │  └────────────────┘ └────────────────┘ │
+# └────────────────────────────────────────┘    └────────────────────────────────────────┘
+
+# ip1 link add dev wg1 type wireguard
+# ip2 link add dev wg1 type wireguard
+n1 $program wg1
+n2 $program wg1
+
+configure_peers
+
+ip1 link add veth1 type veth peer name veth2
+ip1 link set veth2 netns $netns2
+n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth1/accept_dad'
+n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth2/accept_dad'
+n1 bash -c 'printf 1 > /proc/sys/net/ipv4/conf/veth1/promote_secondaries'
+
+# First we check that we aren't overly sticky and can fall over to new IPs when old ones are removed
+ip1 addr add 10.0.0.1/24 dev veth1
+ip1 addr add fd00:aa::1/96 dev veth1
+ip2 addr add 10.0.0.2/24 dev veth2
+ip2 addr add fd00:aa::2/96 dev veth2
+ip1 link set veth1 up
+ip2 link set veth2 up
+waitiface $netns1 veth1
+waitiface $netns2 veth2
+n0 wg set wg1 peer "$pub2" endpoint 10.0.0.2:20000
+n1 ping -W 1 -c 1 192.168.241.2
+ip1 addr add 10.0.0.10/24 dev veth1
+ip1 addr del 10.0.0.1/24 dev veth1
+n1 ping -W 1 -c 1 192.168.241.2
+n0 wg set wg1 peer "$pub2" endpoint [fd00:aa::2]:20000
+n1 ping -W 1 -c 1 192.168.241.2
+ip1 addr add fd00:aa::10/96 dev veth1
+ip1 addr del fd00:aa::1/96 dev veth1
+n1 ping -W 1 -c 1 192.168.241.2
+
+# Now we show that we can successfully do reply to sender routing
+ip1 link set veth1 down
+ip2 link set veth2 down
+ip1 addr flush dev veth1
+ip2 addr flush dev veth2
+ip1 addr add 10.0.0.1/24 dev veth1
+ip1 addr add 10.0.0.2/24 dev veth1
+ip1 addr add fd00:aa::1/96 dev veth1
+ip1 addr add fd00:aa::2/96 dev veth1
+ip2 addr add 10.0.0.3/24 dev veth2
+ip2 addr add fd00:aa::3/96 dev veth2
+ip1 link set veth1 up
+ip2 link set veth2 up
+waitiface $netns1 veth1
+waitiface $netns2 veth2
+n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
+n2 ping -W 1 -c 1 192.168.241.1
+[[ $(n0 wg show wg2 endpoints) == "$pub1    10.0.0.1:10000" ]]
+n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
+n2 ping -W 1 -c 1 192.168.241.1
+[[ $(n0 wg show wg2 endpoints) == "$pub1    [fd00:aa::1]:10000" ]]
+n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
+n2 ping -W 1 -c 1 192.168.241.1
+[[ $(n0 wg show wg2 endpoints) == "$pub1    10.0.0.2:10000" ]]
+n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
+n2 ping -W 1 -c 1 192.168.241.1
+[[ $(n0 wg show wg2 endpoints) == "$pub1    [fd00:aa::2]:10000" ]]
+
+ip1 link del veth1
+ip1 link del wg1
+ip2 link del wg2
index aa96a8adcbc50c69b27d048d377e0d5bee002345..38fcd4a6e139a6faacd9650e7846d2220b8e2a85 100644 (file)
@@ -38,7 +38,7 @@ type Trie struct {
  */
 func commonBits(ip1 []byte, ip2 []byte) uint {
        var i uint
-       size := uint(len(ip1)) / 4
+       size := uint(len(ip1))
 
        for i = 0; i < size; i++ {
                v := ip1[i] ^ ip2[i]
index fd56b5a72364d9ce76a933de8244255ca0fc7287..b5dd663c867602a480eb18c08fad75b6e74752c6 100644 (file)
@@ -44,7 +44,12 @@ func (l *UAPIListener) Accept() (net.Conn, error) {
 }
 
 func (l *UAPIListener) Close() error {
-       return l.listener.Close()
+       err1 := unix.Close(l.inotifyFd)
+       err2 := l.listener.Close()
+       if err1 != nil {
+               return err1
+       }
+       return err2
 }
 
 func (l *UAPIListener) Addr() net.Addr {