]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Fixed port endianness
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 11 Nov 2017 14:43:55 +0000 (15:43 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 11 Nov 2017 14:43:55 +0000 (15:43 +0100)
src/conn.go
src/conn_linux.go
src/device.go
src/receive.go

index b2caffb98e489bb539c10f19f9f739465655082b..aa0b72bb15cc2bd0f7348a5349bfc240192a86f3 100644 (file)
@@ -34,6 +34,21 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
        return addr, err
 }
 
+/* Must hold device and net lock
+ */
+func unsafeCloseUDPListener(device *Device) error {
+       netc := &device.net
+       if netc.bind != nil {
+               if err := netc.bind.Close(); err != nil {
+                       return err
+               }
+               netc.bind = nil
+               netc.update.Broadcast()
+       }
+       return nil
+}
+
+// must inform all listeners
 func UpdateUDPListener(device *Device) error {
        device.mutex.Lock()
        defer device.mutex.Unlock()
@@ -44,26 +59,22 @@ func UpdateUDPListener(device *Device) error {
 
        // close existing sockets
 
-       if netc.bind != nil {
-               println("close bind")
-               if err := netc.bind.Close(); err != nil {
-                       return err
-               }
-               netc.bind = nil
-               println("closed")
+       if err := unsafeCloseUDPListener(device); err != nil {
+               return err
        }
 
+       // wait for reader
+
        // open new sockets
 
        if device.tun.isUp.Get() {
 
-               println("creat")
-
                // bind to new port
 
                var err error
                netc.bind, netc.port, err = CreateUDPBind(netc.port)
                if err != nil {
+                       netc.bind = nil
                        return err
                }
 
@@ -74,8 +85,6 @@ func UpdateUDPListener(device *Device) error {
                        return err
                }
 
-               println("okay")
-
                // clear cached source addresses
 
                for _, peer := range device.peers {
@@ -83,14 +92,20 @@ func UpdateUDPListener(device *Device) error {
                        peer.endpoint.value.ClearSrc()
                        peer.mutex.Unlock()
                }
+
+               // inform readers of updated bind
+
+               netc.update.Broadcast()
        }
 
        return nil
 }
 
 func CloseUDPListener(device *Device) error {
-       netc := &device.net
-       netc.mutex.Lock()
-       defer netc.mutex.Unlock()
-       return netc.bind.Close()
+       device.mutex.Lock()
+       device.net.mutex.Lock()
+       err := unsafeCloseUDPListener(device)
+       device.net.mutex.Unlock()
+       device.mutex.Unlock()
+       return err
 }
index 8cda460af3c01ddf468c3023c37b382c497b9c13..05f93470a57ea054362e99bbd3e4b909e4bb2e8a 100644 (file)
@@ -7,8 +7,8 @@
 package main
 
 import (
+       "encoding/binary"
        "errors"
-       "fmt"
        "golang.org/x/sys/unix"
        "net"
        "strconv"
@@ -37,6 +37,17 @@ type NativeBind struct {
        sock6 int
 }
 
+func htons(val uint16) uint16 {
+       var out [unsafe.Sizeof(val)]byte
+       binary.BigEndian.PutUint16(out[:], val)
+       return *((*uint16)(unsafe.Pointer(&out[0])))
+}
+
+func ntohs(val uint16) uint16 {
+       tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
+       return binary.BigEndian.Uint16((*tmp)[:])
+}
+
 func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
        var err error
        var bind NativeBind
@@ -50,8 +61,6 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
        if err != nil {
                unix.Close(bind.sock6)
        }
-       println(bind.sock6)
-       println(bind.sock4)
        return bind, port, err
 }
 
@@ -297,13 +306,11 @@ func (end *Endpoint) SetDst(s string) error {
                return err
        }
 
-       fmt.Println(addr, err)
-
        ipv4 := addr.IP.To4()
        if ipv4 != nil {
                dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
                dst.Family = unix.AF_INET
-               dst.Port = uint16(addr.Port)
+               dst.Port = htons(uint16(addr.Port))
                dst.Zero = [8]byte{}
                copy(dst.Addr[:], ipv4)
                end.ClearSrc()
@@ -318,7 +325,7 @@ func (end *Endpoint) SetDst(s string) error {
                }
                dst := &end.dst
                dst.Family = unix.AF_INET6
-               dst.Port = uint16(addr.Port)
+               dst.Port = htons(uint16(addr.Port))
                dst.Flowinfo = 0
                dst.Scope_id = zone
                copy(dst.Addr[:], ipv6[:])
@@ -392,9 +399,6 @@ func send6(sock int, end *Endpoint, buff []byte) error {
 }
 
 func send4(sock int, end *Endpoint, buff []byte) error {
-       println("send 4")
-       println(end.DstToString())
-       println(sock)
 
        // construct message header
 
@@ -425,6 +429,7 @@ func send4(sock int, end *Endpoint, buff []byte) error {
                Name:    (*byte)(unsafe.Pointer(&end.dst)),
                Namelen: unix.SizeofSockaddrInet4,
                Control: (*byte)(unsafe.Pointer(&cmsg)),
+               Flags:   0,
        }
        msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
 
@@ -437,10 +442,6 @@ func send4(sock int, end *Endpoint, buff []byte) error {
                0,
        )
 
-       if errno == 0 {
-               return nil
-       }
-
        // clear source and try again
 
        if errno == unix.EINVAL {
@@ -454,6 +455,12 @@ func send4(sock int, end *Endpoint, buff []byte) error {
                )
        }
 
+       // errno = 0 is still an error instance
+
+       if errno == 0 {
+               return nil
+       }
+
        return errno
 }
 
index 1aae4488351c9ef641a81c09d7c0064e50c5a44a..a348c682f611d6026e784350202cbf596b702e04 100644 (file)
@@ -23,9 +23,10 @@ type Device struct {
        }
        net struct {
                mutex  sync.RWMutex
-               bind   UDPBind
-               port   uint16
-               fwmark uint32
+               bind   UDPBind    // bind interface
+               port   uint16     // listening port
+               fwmark uint32     // mark value (0 = disabled)
+               update *sync.Cond // the bind was updated
        }
        mutex        sync.RWMutex
        privateKey   NoisePrivateKey
@@ -38,8 +39,7 @@ type Device struct {
                handshake  chan QueueHandshakeElement
        }
        signal struct {
-               stop       chan struct{}
-               updateBind chan struct{}
+               stop chan struct{}
        }
        underLoadUntil atomic.Value
        ratelimiter    Ratelimiter
@@ -163,6 +163,12 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 
        device.signal.stop = make(chan struct{})
 
+       // prepare net
+
+       device.net.port = 0
+       device.net.bind = nil
+       device.net.update = sync.NewCond(&device.net.mutex)
+
        // start workers
 
        for i := 0; i < runtime.NumCPU(); i += 1 {
index 1f05b2fa064de5d35665d131359a7cd77df64e5d..cb53f804f0a2c7f3fe644449f1abb8f0378c0543 100644 (file)
@@ -99,135 +99,126 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
 
        for {
 
-               // wait for new conn
-
-               logDebug.Println("Waiting for udp socket")
-
-               select {
-               case <-device.signal.stop:
-                       return
-
-               case <-device.signal.updateBind:
-
-                       // fetch new socket
+               // wait for bind
+
+               logDebug.Println("Waiting for udp bind")
+               device.net.mutex.Lock()
+               device.net.update.Wait()
+               bind := device.net.bind
+               device.net.mutex.Unlock()
+               if bind == nil {
+                       continue
+               }
 
-                       device.net.mutex.RLock()
-                       bind := device.net.bind
-                       device.net.mutex.RUnlock()
-                       if bind == nil {
-                               continue
-                       }
+               logDebug.Println("LISTEN\n\n\n")
 
-                       logDebug.Println("Listening for inbound packets")
+               // receive datagrams until conn is closed
 
-                       // receive datagrams until conn is closed
+               buffer := device.GetMessageBuffer()
 
-                       buffer := device.GetMessageBuffer()
+               var size int
+               var err error
 
-                       var size int
-                       var err error
+               for {
 
-                       for {
+                       // read next datagram
 
-                               // read next datagram
+                       var endpoint Endpoint
 
-                               var endpoint Endpoint
-
-                               switch IPVersion {
-                               case ipv4.Version:
-                                       size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
-                               case ipv6.Version:
-                                       size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
-                               default:
-                                       return
-                               }
+                       switch IPVersion {
+                       case ipv4.Version:
+                               size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
+                       case ipv6.Version:
+                               size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
+                       default:
+                               return
+                       }
 
-                               if err != nil {
-                                       break
-                               }
+                       if err != nil {
+                               break
+                       }
 
-                               if size < MinMessageSize {
-                                       continue
-                               }
+                       if size < MinMessageSize {
+                               continue
+                       }
 
-                               // check size of packet
+                       // check size of packet
 
-                               packet := buffer[:size]
-                               msgType := binary.LittleEndian.Uint32(packet[:4])
+                       packet := buffer[:size]
+                       msgType := binary.LittleEndian.Uint32(packet[:4])
 
-                               var okay bool
+                       var okay bool
 
-                               switch msgType {
+                       switch msgType {
 
-                               // check if transport
+                       // check if transport
 
-                               case MessageTransportType:
+                       case MessageTransportType:
 
-                                       // check size
+                               // check size
 
-                                       if len(packet) < MessageTransportType {
-                                               continue
-                                       }
+                               if len(packet) < MessageTransportType {
+                                       continue
+                               }
 
-                                       // lookup key pair
+                               // lookup key pair
 
-                                       receiver := binary.LittleEndian.Uint32(
-                                               packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
-                                       )
-                                       value := device.indices.Lookup(receiver)
-                                       keyPair := value.keyPair
-                                       if keyPair == nil {
-                                               continue
-                                       }
+                               receiver := binary.LittleEndian.Uint32(
+                                       packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+                               )
+                               value := device.indices.Lookup(receiver)
+                               keyPair := value.keyPair
+                               if keyPair == nil {
+                                       continue
+                               }
 
-                                       // check key-pair expiry
+                               // check key-pair expiry
 
-                                       if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
-                                               continue
-                                       }
+                               if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+                                       continue
+                               }
 
-                                       // create work element
+                               // create work element
 
-                                       peer := value.peer
-                                       elem := &QueueInboundElement{
-                                               packet:  packet,
-                                               buffer:  buffer,
-                                               keyPair: keyPair,
-                                               dropped: AtomicFalse,
-                                       }
-                                       elem.mutex.Lock()
+                               peer := value.peer
+                               elem := &QueueInboundElement{
+                                       packet:  packet,
+                                       buffer:  buffer,
+                                       keyPair: keyPair,
+                                       dropped: AtomicFalse,
+                               }
+                               elem.mutex.Lock()
 
-                                       // add to decryption queues
+                               // add to decryption queues
 
-                                       device.addToDecryptionQueue(device.queue.decryption, elem)
-                                       device.addToInboundQueue(peer.queue.inbound, elem)
-                                       buffer = device.GetMessageBuffer()
-                                       continue
+                               device.addToDecryptionQueue(device.queue.decryption, elem)
+                               device.addToInboundQueue(peer.queue.inbound, elem)
+                               buffer = device.GetMessageBuffer()
+                               continue
 
-                               // otherwise it is a fixed size & handshake related packet
+                       // otherwise it is a fixed size & handshake related packet
 
-                               case MessageInitiationType:
-                                       okay = len(packet) == MessageInitiationSize
+                       case MessageInitiationType:
+                               okay = len(packet) == MessageInitiationSize
 
-                               case MessageResponseType:
-                                       okay = len(packet) == MessageResponseSize
+                       case MessageResponseType:
+                               okay = len(packet) == MessageResponseSize
 
-                               case MessageCookieReplyType:
-                                       okay = len(packet) == MessageCookieReplySize
-                               }
+                       case MessageCookieReplyType:
+                               okay = len(packet) == MessageCookieReplySize
+                       }
 
-                               if okay {
-                                       device.addToHandshakeQueue(
-                                               device.queue.handshake,
-                                               QueueHandshakeElement{
-                                                       msgType:  msgType,
-                                                       buffer:   buffer,
-                                                       packet:   packet,
-                                                       endpoint: endpoint,
-                                               },
-                                       )
-                                       buffer = device.GetMessageBuffer()
-                               }
+                       if okay {
+                               device.addToHandshakeQueue(
+                                       device.queue.handshake,
+                                       QueueHandshakeElement{
+                                               msgType:  msgType,
+                                               buffer:   buffer,
+                                               packet:   packet,
+                                               endpoint: endpoint,
+                                       },
+                               )
+                               buffer = device.GetMessageBuffer()
                        }
                }
        }