]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Begin incorporating new src cache into receive
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 7 Oct 2017 20:35:23 +0000 (22:35 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 7 Oct 2017 20:35:23 +0000 (22:35 +0200)
src/conn.go
src/conn_linux.go
src/device.go
src/main.go
src/receive.go

index 60cd789c866b4c4ad145dda1a33c9c6ad1d79523..61be3bfc7b25166e7ef7b686b600a5fb16a3a52b 100644 (file)
@@ -3,7 +3,6 @@ package main
 import (
        "errors"
        "net"
-       "time"
 )
 
 func parseEndpoint(s string) (*net.UDPAddr, error) {
@@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
        return addr, err
 }
 
-func updateUDPConn(device *Device) error {
+func ListenerClose(l *Listener) (err error) {
+       if l.active {
+               err = CloseIPv4Socket(l.sock)
+               l.active = false
+       }
+       return
+}
+
+func (l *Listener) Init() {
+       l.update = make(chan struct{}, 1)
+       ListenerClose(l)
+}
+
+func ListeningUpdate(device *Device) error {
        netc := &device.net
        netc.mutex.Lock()
        defer netc.mutex.Unlock()
 
-       // close existing connection
+       // close existing sockets
 
-       if netc.conn != nil {
-               netc.conn.Close()
-               netc.conn = nil
+       if err := ListenerClose(&netc.ipv4); err != nil {
+               return err
+       }
 
-               // We need for that fd to be closed in all other go routines, which
-               // means we have to wait. TODO: find less horrible way of doing this.
-               time.Sleep(time.Second / 2)
+       if err := ListenerClose(&netc.ipv6); err != nil {
+               return err
        }
 
-       // open new connection
+       // open new sockets
 
        if device.tun.isUp.Get() {
 
-               // listen on new address
-
-               conn, err := net.ListenUDP("udp", netc.addr)
-               if err != nil {
-                       return err
+               // listen on IPv4
+
+               {
+                       list := &netc.ipv6
+                       sock, port, err := CreateIPv4Socket(netc.port)
+                       if err != nil {
+                               return err
+                       }
+                       netc.port = port
+                       list.sock = sock
+                       list.active = true
+
+                       if err := SetMark(list.sock, netc.fwmark); err != nil {
+                               ListenerClose(list)
+                               return err
+                       }
+                       signalSend(list.update)
                }
 
-               // set fwmark
-
-               err = SetMark(netc.conn, netc.fwmark)
-               if err != nil {
-                       return err
+               // listen on IPv6
+
+               {
+                       list := &netc.ipv6
+                       sock, port, err := CreateIPv6Socket(netc.port)
+                       if err != nil {
+                               return err
+                       }
+                       netc.port = port
+                       list.sock = sock
+                       list.active = true
+
+                       if err := SetMark(list.sock, netc.fwmark); err != nil {
+                               ListenerClose(list)
+                               return err
+                       }
+                       signalSend(list.update)
                }
 
-               // retrieve port (may have been chosen by kernel)
-
-               addr := conn.LocalAddr()
-               netc.conn = conn
-               netc.addr, _ = net.ResolveUDPAddr(
-                       addr.Network(),
-                       addr.String(),
-               )
-
-               // notify goroutines
-
-               signalSend(device.signal.newUDPConn)
+               // TODO: clear endpoint caches
        }
 
        return nil
 }
 
-func closeUDPConn(device *Device) {
+func ListeningClose(device *Device) error {
        netc := &device.net
        netc.mutex.Lock()
-       if netc.conn != nil {
-               netc.conn.Close()
+       defer netc.mutex.Unlock()
+
+       if err := ListenerClose(&netc.ipv4); err != nil {
+               return err
        }
-       netc.mutex.Unlock()
-       signalSend(device.signal.newUDPConn)
+       signalSend(netc.ipv4.update)
+
+       if err := ListenerClose(&netc.ipv6); err != nil {
+               return err
+       }
+       signalSend(netc.ipv6.update)
+
+       return nil
 }
index 64447a547800503bac49a9927dbd0cdca2dc081d..034fb8bfe616f5174be754d8198e86c966732472 100644 (file)
@@ -28,6 +28,7 @@ import "fmt"
 type Endpoint struct {
        // source (selected based on dst type)
        // (could use RawSockaddrAny and unsafe)
+       // TODO: Merge
        src6   unix.RawSockaddrInet6
        src4   unix.RawSockaddrInet4
        src4if int32
@@ -35,8 +36,14 @@ type Endpoint struct {
        dst unix.RawSockaddrAny
 }
 
-type IPv4Socket int
-type IPv6Socket int
+type Socket int
+
+/* Returns a byte representation of the source field(s)
+ * for use in "under load" cookie computations.
+ */
+func (endpoint *Endpoint) Source() []byte {
+       return nil
+}
 
 func zoneToUint32(zone string) (uint32, error) {
        if zone == "" {
@@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) {
        return uint32(n), err
 }
 
-func CreateIPv4Socket(port int) (IPv4Socket, error) {
+func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
 
        // create socket
 
@@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
        )
 
        if err != nil {
-               return -1, err
+               return -1, 0, err
+       }
+
+       addr := unix.SockaddrInet4{
+               Port: int(port),
        }
 
        // set sockopts and bind
 
        if err := func() error {
-
                if err := unix.SetsockoptInt(
                        fd,
                        unix.SOL_SOCKET,
@@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
                        return err
                }
 
-               addr := unix.SockaddrInet4{
-                       Port: port,
-               }
                return unix.Bind(fd, &addr)
-
        }(); err != nil {
                unix.Close(fd)
        }
 
-       return IPv4Socket(fd), err
+       return Socket(fd), uint16(addr.Port), err
 }
 
-func CreateIPv6Socket(port int) (IPv6Socket, error) {
+func CloseIPv4Socket(sock Socket) error {
+       return unix.Close(int(sock))
+}
+
+func CloseIPv6Socket(sock Socket) error {
+       return unix.Close(int(sock))
+}
+
+func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
 
        // create socket
 
@@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
        )
 
        if err != nil {
-               return -1, err
+               return -1, 0, err
        }
 
        // set sockopts and bind
 
+       addr := unix.SockaddrInet6{
+               Port: int(port),
+       }
+
        if err := func() error {
 
                if err := unix.SetsockoptInt(
@@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
                        return err
                }
 
-               addr := unix.SockaddrInet6{
-                       Port: port,
-               }
                return unix.Bind(fd, &addr)
 
        }(); err != nil {
                unix.Close(fd)
        }
 
-       return IPv6Socket(fd), err
+       return Socket(fd), uint16(addr.Port), err
 }
 
 func (end *Endpoint) ClearSrc() {
@@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
        return errors.New("Unknown address family of source")
 }
 
-func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
+func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
 
        // contruct message header
 
@@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
        return int(size), nil
 }
 
-func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
+func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
 
        // contruct message header
 
@@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
 
        // recvmsg(sock, &mskhdr, 0)
 
-       _, _, errno := unix.Syscall(
+       size, _, errno := unix.Syscall(
                unix.SYS_RECVMSG,
                uintptr(sock),
                uintptr(unsafe.Pointer(&msg)),
@@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
        )
 
        if errno != 0 {
-               return errno
+               return 0, errno
        }
 
        // update source cache
@@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
                end.src6.Scope_id = cmsg.pktinfo.Ifindex
        }
 
-       return nil
+       return int(size), nil
 }
 
-func SetMark(conn *net.UDPConn, value uint32) error {
-       if conn == nil {
-               return nil
-       }
-
-       file, err := conn.File()
-       if err != nil {
-               return err
-       }
-
+func SetMark(sock Socket, value uint32) error {
        return unix.SetsockoptInt(
-               int(file.Fd()),
+               int(sock),
                unix.SOL_SOCKET,
                unix.SO_MARK,
                int(value),
index 61c87bc99d7fda370384906c7e76a6fcbe67781c..509e6a741bc632010fe9eddc72a57d3484001339 100644 (file)
@@ -1,13 +1,18 @@
 package main
 
 import (
-       "net"
        "runtime"
        "sync"
        "sync/atomic"
        "time"
 )
 
+type Listener struct {
+       sock   Socket
+       active bool
+       update chan struct{}
+}
+
 type Device struct {
        log       *Logger // collection of loggers for levels
        idCounter uint    // for assigning debug ids to peers
@@ -22,8 +27,9 @@ type Device struct {
        }
        net struct {
                mutex  sync.RWMutex
-               addr   *net.UDPAddr // UDP source address
-               conn   *net.UDPConn // UDP "connection"
+               ipv4   Listener
+               ipv6   Listener
+               port   uint16
                fwmark uint32
        }
        mutex        sync.RWMutex
@@ -37,8 +43,9 @@ type Device struct {
                handshake  chan QueueHandshakeElement
        }
        signal struct {
-               stop       chan struct{} // halts all go routines
-               newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
+               stop             chan struct{} // halts all go routines
+               updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
+               updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
        }
        underLoadUntil atomic.Value
        ratelimiter    Ratelimiter
@@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        device.log = NewLogger(logLevel, "("+tun.Name()+") ")
        device.peers = make(map[NoisePublicKey]*Peer)
        device.tun.device = tun
+
        device.indices.Init()
+       device.net.ipv4.Init()
+       device.net.ipv6.Init()
        device.ratelimiter.Init()
+
        device.routingTable.Reset()
        device.underLoadUntil.Store(time.Time{})
 
-       // setup pools
+       // setup buffer pool
 
        device.pool.messageBuffers = sync.Pool{
                New: func() interface{} {
@@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        // prepare signals
 
        device.signal.stop = make(chan struct{})
-       device.signal.newUDPConn = make(chan struct{}, 1)
 
        // start workers
 
@@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
                go device.RoutineDecryption()
                go device.RoutineHandshake()
        }
-
+       go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
-       go device.RoutineReadFromTUN()
-       go device.RoutineReceiveIncomming()
-
+       go device.RoutineReceiveIncomming(&device.net.ipv4)
+       go device.RoutineReceiveIncomming(&device.net.ipv6)
        return device
 }
 
@@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() {
 func (device *Device) Close() {
        device.RemoveAllPeers()
        close(device.signal.stop)
-       closeUDPConn(device)
+       ListeningClose(device)
 }
 
 func (device *Device) WaitChannel() chan struct{} {
index 196a4c607add27a7f06b0a8ad3e2b67772c754b6..a05dbba238fef3384cfc111faf09c06370e19336 100644 (file)
@@ -14,6 +14,7 @@ func printUsage() {
 }
 
 func main() {
+       test()
 
        // parse arguments
 
index 52c271803437a2c34a125b65097ad7a4aebf4086..60c0f2c7137373111ea37817261194a1816dfcb8 100644 (file)
@@ -13,10 +13,10 @@ import (
 )
 
 type QueueHandshakeElement struct {
-       msgType uint32
-       packet  []byte
-       buffer  *[MaxMessageSize]byte
-       source  *net.UDPAddr
+       msgType  uint32
+       packet   []byte
+       endpoint Endpoint
+       buffer   *[MaxMessageSize]byte
 }
 
 type QueueInboundElement struct {
@@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue(
        }
 }
 
-func (device *Device) RoutineReceiveIncomming() {
+func (device *Device) RoutineReceiveIncomming(IPVersion int) {
 
        logDebug := device.log.Debug
        logDebug.Println("Routine, receive incomming, started")
 
+       var listener *Listener
+
+       switch IPVersion {
+       case ipv4.Version:
+               listener = &device.net.ipv4
+       case ipv6.Version:
+               listener = &device.net.ipv6
+       default:
+               return
+       }
+
        for {
 
                // wait for new conn
@@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() {
                case <-device.signal.stop:
                        return
 
-               case <-device.signal.newUDPConn:
+               case <-listener.update:
 
-                       // fetch connection
+                       // fetch new socket
 
                        device.net.mutex.RLock()
-                       conn := device.net.conn
+                       sock := listener.sock
+                       okay := listener.active
                        device.net.mutex.RUnlock()
-                       if conn == nil {
+                       if !okay {
                                continue
                        }
 
@@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() {
 
                        buffer := device.GetMessageBuffer()
 
+                       var size int
+                       var err error
+
                        for {
 
                                // read next datagram
 
-                               size, raddr, err := conn.ReadFromUDP(buffer[:])
+                               var endpoint Endpoint
+
+                               if IPVersion == ipv6.Version {
+                                       size, err = endpoint.ReceiveIPv4(sock, buffer[:])
+                               } else {
+                                       size, err = endpoint.ReceiveIPv6(sock, buffer[:])
+                               }
 
                                if err != nil {
                                        break
@@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() {
                                        buffer = device.GetMessageBuffer()
                                        continue
 
-                               // otherwise it is a handshake related packet
+                               // otherwise it is a fixed size & handshake related packet
 
                                case MessageInitiationType:
                                        okay = len(packet) == MessageInitiationSize
@@ -208,10 +229,10 @@ func (device *Device) RoutineReceiveIncomming() {
                                        device.addToHandshakeQueue(
                                                device.queue.handshake,
                                                QueueHandshakeElement{
-                                                       msgType: msgType,
-                                                       buffer:  buffer,
-                                                       packet:  packet,
-                                                       source:  raddr,
+                                                       msgType:  msgType,
+                                                       buffer:   buffer,
+                                                       packet:   packet,
+                                                       endpoint: endpoint,
                                                },
                                        )
                                        buffer = device.GetMessageBuffer()
@@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() {
 
                        // unmarshal packet
 
-                       logDebug.Println("Process cookie reply from:", elem.source.String())
-
                        var reply MessageCookieReply
                        reader := bytes.NewReader(elem.packet)
                        err := binary.Read(reader, binary.LittleEndian, &reply)