]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Clear src cache if route changes to new ifindex
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 27 Apr 2018 03:21:45 +0000 (05:21 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 27 Apr 2018 03:41:07 +0000 (05:41 +0200)
conn_linux.go
tun_linux.go

index 88b9ef438795faae624b7a9887fdd1a2c82ce59b..ff3c4839e02f3485f78eaca51aae9f3dd4fd6aed 100644 (file)
@@ -53,12 +53,15 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
 }
 
 type NativeBind struct {
-       sock4 int
-       sock6 int
+       sock4        int
+       sock6        int
+       netlinkSock  int
+       lastEndpoint *NativeEndpoint
+       lastMark     uint32
 }
 
 var _ Endpoint = (*NativeEndpoint)(nil)
-var _ Bind = NativeBind{}
+var _ Bind = (*NativeBind)(nil)
 
 func CreateEndpoint(s string) (Endpoint, error) {
        var end NativeEndpoint
@@ -95,23 +98,50 @@ func CreateEndpoint(s string) (Endpoint, error) {
        return nil, errors.New("Invalid IP address")
 }
 
-func CreateBind(port uint16) (Bind, uint16, error) {
+func createNetlinkRouteSocket() (int, error) {
+       sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+       if err != nil {
+               return -1, err
+       }
+       saddr := &unix.SockaddrNetlink{
+               Family: unix.AF_NETLINK,
+               Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+       }
+       err = unix.Bind(sock, saddr)
+       if err != nil {
+               unix.Close(sock)
+               return -1, err
+       }
+       return sock, nil
+
+}
+
+func CreateBind(port uint16) (*NativeBind, uint16, error) {
        var err error
        var bind NativeBind
 
+       bind.netlinkSock, err = createNetlinkRouteSocket()
+       if err != nil {
+               return nil, 0, err
+       }
+
+       go bind.routineRouteListener()
+
        bind.sock6, port, err = create6(port)
        if err != nil {
+               unix.Close(bind.netlinkSock)
                return nil, port, err
        }
 
        bind.sock4, port, err = create4(port)
        if err != nil {
+               unix.Close(bind.netlinkSock)
                unix.Close(bind.sock6)
        }
-       return bind, port, err
+       return &bind, port, err
 }
 
-func (bind NativeBind) SetMark(value uint32) error {
+func (bind *NativeBind) SetMark(value uint32) error {
        err := unix.SetsockoptInt(
                bind.sock6,
                unix.SOL_SOCKET,
@@ -123,12 +153,19 @@ func (bind NativeBind) SetMark(value uint32) error {
                return err
        }
 
-       return unix.SetsockoptInt(
+       err = unix.SetsockoptInt(
                bind.sock4,
                unix.SOL_SOCKET,
                unix.SO_MARK,
                int(value),
        )
+
+       if err != nil {
+               return err
+       }
+
+       bind.lastMark = value
+       return nil
 }
 
 func closeUnblock(fd int) error {
@@ -137,16 +174,20 @@ func closeUnblock(fd int) error {
        return unix.Close(fd)
 }
 
-func (bind NativeBind) Close() error {
+func (bind *NativeBind) Close() error {
        err1 := closeUnblock(bind.sock6)
        err2 := closeUnblock(bind.sock4)
+       err3 := closeUnblock(bind.netlinkSock)
        if err1 != nil {
                return err1
        }
-       return err2
+       if err2 != nil {
+               return err2
+       }
+       return err3
 }
 
-func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        var end NativeEndpoint
        n, err := receive6(
                bind.sock6,
@@ -156,17 +197,18 @@ func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        return n, &end, err
 }
 
-func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        var end NativeEndpoint
        n, err := receive4(
                bind.sock4,
                buff,
                &end,
        )
+       bind.lastEndpoint = &end
        return n, &end, err
 }
 
-func (bind NativeBind) Send(buff []byte, end Endpoint) error {
+func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
        nend := end.(*NativeEndpoint)
        if !nend.isV6 {
                return send4(bind.sock4, nend, buff)
@@ -506,3 +548,97 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 
        return size, nil
 }
+
+func (bind *NativeBind) routineRouteListener() {
+       // TODO: this function doesn't lock the endpoint it modifies
+
+       for msg := make([]byte, 1<<16); ; {
+               msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+               if err != nil {
+                       return
+               }
+
+               for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+                       hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+                       if uint(hdr.Len) > uint(len(remain)) {
+                               break
+                       }
+
+                       switch hdr.Type {
+                       case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+
+                               if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
+                                       break
+                               }
+
+                               if hdr.Seq == 0xff {
+                                       if uint(len(remain)) < uint(hdr.Len) {
+                                               break
+                                       }
+                                       if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
+                                               attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
+                                               for {
+                                                       if uint(len(attr)) < uint(unix.SizeofRtAttr) {
+                                                               break
+                                                       }
+                                                       attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
+                                                       if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
+                                                               break
+                                                       }
+                                                       if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
+                                                               ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
+                                                               if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
+                                                                       bind.lastEndpoint.ClearSrc()
+                                                               }
+                                                       }
+                                                       attr = attr[attrhdr.Len:]
+                                               }
+                                       }
+                                       break
+                               }
+
+                               nlmsg := struct {
+                                       hdr     unix.NlMsghdr
+                                       msg     unix.RtMsg
+                                       dsthdr  unix.RtAttr
+                                       dst     [4]byte
+                                       srchdr  unix.RtAttr
+                                       src     [4]byte
+                                       markhdr unix.RtAttr
+                                       mark    uint32
+                               }{
+                                       unix.NlMsghdr{
+                                               Type:  uint16(unix.RTM_GETROUTE),
+                                               Flags: unix.NLM_F_REQUEST,
+                                               Seq:   0xff,
+                                       },
+                                       unix.RtMsg{
+                                               Family:  unix.AF_INET,
+                                               Dst_len: 32,
+                                               Src_len: 32,
+                                       },
+                                       unix.RtAttr{
+                                               Len:  8,
+                                               Type: unix.RTA_DST,
+                                       },
+                                       bind.lastEndpoint.dst4().Addr,
+                                       unix.RtAttr{
+                                               Len:  8,
+                                               Type: unix.RTA_SRC,
+                                       },
+                                       bind.lastEndpoint.src4().src,
+                                       unix.RtAttr{
+                                               Len:  8,
+                                               Type: 0x10, //unix.RTA_MARK  TODO: add this to x/sys/unix
+                                       },
+                                       uint32(bind.lastMark),
+                               }
+                               nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+                               unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+                       }
+                       remain = remain[hdr.Len:]
+               }
+       }
+}
index 0672b5e2227a51f7f2a0ea961245ba88a22a6c95..b0ffa00741e45a98e993d52b999ef816b5396eab 100644 (file)
@@ -79,7 +79,6 @@ func (tun *NativeTun) RoutineNetlinkListener() {
        defer unix.Close(sock)
        saddr := &unix.SockaddrNetlink{
                Family: unix.AF_NETLINK,
-               Pid:    uint32(os.Getpid()),
                Groups: uint32(groups),
        }
        err = unix.Bind(sock, saddr)
@@ -90,7 +89,9 @@ func (tun *NativeTun) RoutineNetlinkListener() {
 
        // TODO: This function never actually exits in response to anything,
        // a go routine that goes forever. We'll want to fix that if this is
-       // to ever be used as any sort of library.
+       // to ever be used as any sort of library. See what we've done with
+       // calling shutdown() on the netlink socket in conn_linux.go, and
+       // change this to be more like that.
 
        for msg := make([]byte, 1<<16); ; {