]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Detects interface status on linux
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Wed, 16 Aug 2017 22:25:39 +0000 (00:25 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Wed, 16 Aug 2017 22:25:39 +0000 (00:25 +0200)
src/tun_linux.go

index 34f746a70b469960669988dd5e84f2690b345171..476a43f9b43347c3bb775be4834c81b2bad82773 100644 (file)
@@ -6,6 +6,7 @@ package main
 import (
        "encoding/binary"
        "errors"
+       "fmt"
        "golang.org/x/sys/unix"
        "net"
        "os"
@@ -13,12 +14,93 @@ import (
        "unsafe"
 )
 
-const CloneDevicePath = "/dev/net/tun"
+// #include <string.h>
+// #include <unistd.h>
+// #include <net/if.h>
+// #include <netinet/in.h>
+// #include <linux/netlink.h>
+// #include <linux/rtnetlink.h>
+//
+// /* Creates a netlink socket
+//  * listening to the RTMGRP_LINK multicast group
+//  */
+//
+// int bind_rtmgrp() {
+//   int nl_sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
+//   if (nl_sock < 0)
+//     return -1;
+//
+//      struct sockaddr_nl addr;
+//   memset ((void *) &addr, 0, sizeof (addr));
+//   addr.nl_family = AF_NETLINK;
+//   addr.nl_pid = getpid ();
+//   addr.nl_groups = RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR;
+//
+//   if (bind(nl_sock, (struct sockaddr *) &addr, sizeof (addr)) < 0)
+//     return -1;
+//
+//   return nl_sock;
+// }
+import "C"
+
+const (
+       CloneDevicePath = "/dev/net/tun"
+       IFReqSize       = unix.IFNAMSIZ + 64
+)
 
 type NativeTun struct {
        fd     *os.File
+       index  int
        name   string
-       events chan TUNEvent
+       errors chan error    // async error handling
+       events chan TUNEvent //
+}
+
+func (tun *NativeTun) RoutineNetlinkListener() {
+       sock := int(C.bind_rtmgrp())
+       if sock < 0 {
+               tun.errors <- errors.New("Failed to create netlink event listener")
+               return
+       }
+
+       for msg := make([]byte, 1<<16); ; {
+
+               msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0)
+               if err != nil {
+                       tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error())
+                       return
+               }
+
+               for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+                       hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+                       if int(hdr.Len) > len(remain) {
+                               break
+                       }
+
+                       switch hdr.Type {
+                       case unix.NLMSG_DONE:
+                               remain = []byte{}
+
+                       case unix.RTM_NEWLINK:
+                               info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
+
+                               if info.Flags&unix.IFF_RUNNING != 0 {
+                                       tun.events <- TUNEventUp
+                               }
+
+                               if info.Flags&unix.IFF_RUNNING == 0 {
+                                       tun.events <- TUNEventDown
+                               }
+
+                               remain = remain[hdr.Len:]
+
+                       default:
+                               remain = remain[hdr.Len:]
+                       }
+               }
+       }
 }
 
 func (tun *NativeTun) isUp() (bool, error) {
@@ -30,6 +112,46 @@ func (tun *NativeTun) Name() string {
        return tun.name
 }
 
+func toInt32(val []byte) int {
+       n := binary.LittleEndian.Uint32(val[:4])
+       if n >= (1 << 31) {
+               return int(n-(1<<31)) - (1 << 31)
+       }
+       return int(n)
+}
+
+func getDummySock() (int, error) {
+       return unix.Socket(
+               unix.AF_INET,
+               unix.SOCK_DGRAM,
+               0,
+       )
+}
+
+func getIFIndex(name string) (int, error) {
+       fd, err := getDummySock()
+       if err != nil {
+               return 0, err
+       }
+
+       defer unix.Close(fd)
+
+       var ifr [IFReqSize]byte
+       copy(ifr[:], name)
+       _, _, errno := unix.Syscall(
+               unix.SYS_IOCTL,
+               uintptr(fd),
+               uintptr(unix.SIOCGIFINDEX),
+               uintptr(unsafe.Pointer(&ifr[0])),
+       )
+
+       if errno != 0 {
+               return 0, errno
+       }
+
+       return toInt32(ifr[unix.IFNAMSIZ:]), nil
+}
+
 func (tun *NativeTun) setMTU(n int) error {
 
        // open datagram socket
@@ -48,7 +170,7 @@ func (tun *NativeTun) setMTU(n int) error {
 
        // do ioctl call
 
-       var ifr [64]byte
+       var ifr [IFReqSize]byte
        copy(ifr[:], tun.name)
        binary.LittleEndian.PutUint32(ifr[16:20], uint32(n))
        _, _, errno := unix.Syscall(
@@ -83,7 +205,7 @@ func (tun *NativeTun) MTU() (int, error) {
 
        // do ioctl call
 
-       var ifr [64]byte
+       var ifr [IFReqSize]byte
        copy(ifr[:], tun.name)
        _, _, errno := unix.Syscall(
                unix.SYS_IOCTL,
@@ -109,7 +231,12 @@ func (tun *NativeTun) Write(d []byte) (int, error) {
 }
 
 func (tun *NativeTun) Read(d []byte) (int, error) {
-       return tun.fd.Read(d)
+       select {
+       case err := <-tun.errors:
+               return 0, err
+       default:
+               return tun.fd.Read(d)
+       }
 }
 
 func (tun *NativeTun) Events() chan TUNEvent {
@@ -131,11 +258,11 @@ func CreateTUN(name string) (TUNDevice, error) {
 
        // create new device
 
-       var ifr [64]byte
+       var ifr [IFReqSize]byte
        var flags uint16 = unix.IFF_TUN | unix.IFF_NO_PI
        nameBytes := []byte(name)
        if len(nameBytes) >= unix.IFNAMSIZ {
-               return nil, errors.New("Name size too long")
+               return nil, errors.New("Interface name too long")
        }
        copy(ifr[:], nameBytes)
        binary.LittleEndian.PutUint16(ifr[16:], flags)
@@ -147,7 +274,7 @@ func CreateTUN(name string) (TUNDevice, error) {
                uintptr(unsafe.Pointer(&ifr[0])),
        )
        if errno != 0 {
-               return nil, errors.New("Failed to create tun, ioctl call failed")
+               return nil, errno
        }
 
        // read (new) name of interface
@@ -158,13 +285,21 @@ func CreateTUN(name string) (TUNDevice, error) {
                fd:     fd,
                name:   newName,
                events: make(chan TUNEvent, 5),
+               errors: make(chan error, 5),
        }
 
-       // TODO: Wait for device to be upped
-       device.events <- TUNEventUp
+       // fetch IF index
+
+       device.index, err = getIFIndex(device.name)
+       if err != nil {
+               return nil, err
+       }
+
+       go device.RoutineNetlinkListener()
 
        // set default MTU
 
-       err = device.setMTU(DefaultMTU)
-       return device, err
+       fmt.Println(device)
+
+       return device, device.setMTU(DefaultMTU)
 }