]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Added missing IF index check
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 17 Aug 2017 10:58:18 +0000 (12:58 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 17 Aug 2017 10:58:18 +0000 (12:58 +0200)
src/conn.go
src/device.go
src/tun_linux.go

index f6472e93140fb8cc1ee058774cbac071b924bd63..e23b3506f25b552aeb1046c68a8a85899b6f687b 100644 (file)
@@ -5,9 +5,9 @@ import (
 )
 
 func updateUDPConn(device *Device) error {
-       var err error
        netc := &device.net
        netc.mutex.Lock()
+       defer netc.mutex.Unlock()
 
        // close existing connection
 
@@ -18,15 +18,23 @@ func updateUDPConn(device *Device) error {
        // open new connection
 
        if device.tun.isUp.Get() {
+
+               // listen on new address
+
                conn, err := net.ListenUDP("udp", netc.addr)
-               if err == nil {
-                       netc.conn = conn
-                       signalSend(device.signal.newUDPConn)
+               if err != nil {
+                       return err
                }
+
+               // retrieve port (may have been chosen by kernel)
+
+               addr := conn.LocalAddr()
+               netc.conn = conn
+               netc.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String())
+               signalSend(device.signal.newUDPConn)
        }
 
-       netc.mutex.Unlock()
-       return err
+       return nil
 }
 
 func closeUDPConn(device *Device) {
index dfd2f35541e97179b7ae9e55aa0544be9f943a76..9bcd2f58df8f64be17ddf9aeb5013a814d26dad3 100644 (file)
@@ -196,15 +196,19 @@ func (device *Device) RoutineTUNEventReader() {
                }
 
                if event&TUNEventUp != 0 {
-                       device.tun.isUp.Set(true)
-                       updateUDPConn(device)
-                       logInfo.Println("Interface set up")
+                       if !device.tun.isUp.Get() {
+                               device.tun.isUp.Set(true)
+                               updateUDPConn(device)
+                               logInfo.Println("Interface set up")
+                       }
                }
 
                if event&TUNEventDown != 0 {
-                       device.tun.isUp.Set(false)
-                       closeUDPConn(device)
-                       logInfo.Println("Interface set down")
+                       if device.tun.isUp.Get() {
+                               device.tun.isUp.Set(false)
+                               closeUDPConn(device)
+                               logInfo.Println("Interface set down")
+                       }
                }
        }
 }
index 476a43f9b43347c3bb775be4834c81b2bad82773..e75273325628aaf9462f81dbac11528fc8f5b134 100644 (file)
@@ -50,10 +50,10 @@ const (
 
 type NativeTun struct {
        fd     *os.File
-       index  int
-       name   string
+       index  int32         // if index
+       name   string        // name of interface
        errors chan error    // async error handling
-       events chan TUNEvent //
+       events chan TUNEvent // device related events
 }
 
 func (tun *NativeTun) RoutineNetlinkListener() {
@@ -86,6 +86,11 @@ func (tun *NativeTun) RoutineNetlinkListener() {
                        case unix.RTM_NEWLINK:
                                info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
 
+                               if info.Index != tun.index {
+                                       // not our interface
+                                       continue
+                               }
+
                                if info.Flags&unix.IFF_RUNNING != 0 {
                                        tun.events <- TUNEventUp
                                }
@@ -112,12 +117,12 @@ func (tun *NativeTun) Name() string {
        return tun.name
 }
 
-func toInt32(val []byte) int {
+func toInt32(val []byte) int32 {
        n := binary.LittleEndian.Uint32(val[:4])
        if n >= (1 << 31) {
-               return int(n-(1<<31)) - (1 << 31)
+               return -int32(^n) - 1
        }
-       return int(n)
+       return int32(n)
 }
 
 func getDummySock() (int, error) {
@@ -128,7 +133,7 @@ func getDummySock() (int, error) {
        )
 }
 
-func getIFIndex(name string) (int, error) {
+func getIFIndex(name string) (int32, error) {
        fd, err := getDummySock()
        if err != nil {
                return 0, err
@@ -288,7 +293,7 @@ func CreateTUN(name string) (TUNDevice, error) {
                errors: make(chan error, 5),
        }
 
-       // fetch IF index
+       // start event listener
 
        device.index, err = getIFIndex(device.name)
        if err != nil {
@@ -299,7 +304,5 @@ func CreateTUN(name string) (TUNDevice, error) {
 
        // set default MTU
 
-       fmt.Println(device)
-
        return device, device.setMTU(DefaultMTU)
 }