]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Look up route for every peer
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 01:00:40 +0000 (03:00 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 01:26:38 +0000 (03:26 +0200)
conn.go
conn_default.go
conn_linux.go

diff --git a/conn.go b/conn.go
index 4b347ec1b042dee039491e6f58d4617b493fb1c1..92f4cfe8808734a72f67a1db049afd9a5f59e474 100644 (file)
--- a/conn.go
+++ b/conn.go
@@ -123,7 +123,7 @@ func (device *Device) BindUpdate() error {
 
                var err error
                netc := &device.net
-               netc.bind, netc.port, err = CreateBind(netc.port)
+               netc.bind, netc.port, err = CreateBind(netc.port, device)
                if err != nil {
                        netc.bind = nil
                        netc.port = 0
index 047d5f6df275a29cf45af033d6a347b57b25618c..755621015d25a711ad799404e6836d50c80b7b46 100644 (file)
@@ -81,7 +81,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
        return conn, uaddr.Port, nil
 }
 
-func CreateBind(uport uint16) (Bind, uint16, error) {
+func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
        var err error
        var bind NativeBind
 
index a42813895a1a1e83c540135d9927864e127a0c2a..2b920bf663d8cd938147a6a524985a91f9fe73af 100644 (file)
@@ -55,11 +55,10 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
 }
 
 type NativeBind struct {
-       sock4        int
-       sock6        int
-       netlinkSock  int
-       lastEndpoint *NativeEndpoint
-       lastMark     uint32
+       sock4       int
+       sock6       int
+       netlinkSock int
+       lastMark    uint32
 }
 
 var _ Endpoint = (*NativeEndpoint)(nil)
@@ -118,7 +117,7 @@ func createNetlinkRouteSocket() (int, error) {
 
 }
 
-func CreateBind(port uint16) (*NativeBind, uint16, error) {
+func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
        var err error
        var bind NativeBind
 
@@ -127,7 +126,7 @@ func CreateBind(port uint16) (*NativeBind, uint16, error) {
                return nil, 0, err
        }
 
-       go bind.routineRouteListener()
+       go bind.routineRouteListener(device)
 
        bind.sock6, port, err = create6(port)
        if err != nil {
@@ -171,8 +170,8 @@ func (bind *NativeBind) SetMark(value uint32) error {
 }
 
 func closeUnblock(fd int) error {
-       // shutdown to unblock readers
-       unix.Shutdown(fd, unix.SHUT_RD)
+       // shutdown to unblock readers and writers
+       unix.Shutdown(fd, unix.SHUT_RDWR)
        return unix.Close(fd)
 }
 
@@ -206,7 +205,6 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
                buff,
                &end,
        )
-       bind.lastEndpoint = &end
        return n, &end, err
 }
 
@@ -551,8 +549,8 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
        return size, nil
 }
 
-func (bind *NativeBind) routineRouteListener() {
-       // TODO: this function doesn't lock the endpoint it modifies
+func (bind *NativeBind) routineRouteListener(device *Device) {
+       var reqPeer map[uint32]*Peer
 
        for msg := make([]byte, 1<<16); ; {
                msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
@@ -570,12 +568,7 @@ func (bind *NativeBind) routineRouteListener() {
 
                        switch hdr.Type {
                        case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
-
-                               if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
-                                       break
-                               }
-
-                               if hdr.Seq == 0xff {
+                               if hdr.Seq <= MaxPeers {
                                        if uint(len(remain)) < uint(hdr.Len) {
                                                break
                                        }
@@ -591,54 +584,90 @@ func (bind *NativeBind) routineRouteListener() {
                                                        }
                                                        if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
                                                                ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
-                                                               if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
-                                                                       bind.lastEndpoint.ClearSrc()
+                                                               if reqPeer == nil {
+                                                                       break
+                                                               }
+                                                               peer, ok := reqPeer[hdr.Seq]
+                                                               if !ok {
+                                                                       break
+                                                               }
+                                                               peer.mutex.RLock()
+                                                               if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
+                                                                       peer.mutex.RUnlock()
+                                                                       break
+                                                               }
+                                                               if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
+                                                                       peer.mutex.RUnlock()
+                                                                       break
                                                                }
+                                                               if uint32(peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
+                                                                       peer.mutex.RUnlock()
+                                                                       break
+                                                               }
+                                                               peer.mutex.RUnlock()
+                                                               peer.mutex.Lock()
+                                                               peer.endpoint.(*NativeEndpoint).ClearSrc()
+                                                               peer.mutex.Unlock()
                                                        }
                                                        attr = attr[attrhdr.Len:]
                                                }
                                        }
                                        break
                                }
-
-                               nlmsg := struct {
-                                       hdr     unix.NlMsghdr
-                                       msg     unix.RtMsg
-                                       dsthdr  unix.RtAttr
-                                       dst     [4]byte
-                                       srchdr  unix.RtAttr
-                                       src     [4]byte
-                                       markhdr unix.RtAttr
-                                       mark    uint32
-                               }{
-                                       unix.NlMsghdr{
-                                               Type:  uint16(unix.RTM_GETROUTE),
-                                               Flags: unix.NLM_F_REQUEST,
-                                               Seq:   0xff,
-                                       },
-                                       unix.RtMsg{
-                                               Family:  unix.AF_INET,
-                                               Dst_len: 32,
-                                               Src_len: 32,
-                                       },
-                                       unix.RtAttr{
-                                               Len:  8,
-                                               Type: unix.RTA_DST,
-                                       },
-                                       bind.lastEndpoint.dst4().Addr,
-                                       unix.RtAttr{
-                                               Len:  8,
-                                               Type: unix.RTA_SRC,
-                                       },
-                                       bind.lastEndpoint.src4().src,
-                                       unix.RtAttr{
-                                               Len:  8,
-                                               Type: 0x10, //unix.RTA_MARK  TODO: add this to x/sys/unix
-                                       },
-                                       uint32(bind.lastMark),
-                               }
-                               nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
-                               unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+                               reqPeer = make(map[uint32]*Peer)
+                               go func() {
+                                       device.peers.mutex.RLock()
+                                       i := uint32(1)
+                                       for _, peer := range device.peers.keyMap {
+                                               peer.mutex.RLock()
+                                               if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
+                                                       peer.mutex.RUnlock()
+                                                       continue
+                                               }
+                                               nlmsg := struct {
+                                                       hdr     unix.NlMsghdr
+                                                       msg     unix.RtMsg
+                                                       dsthdr  unix.RtAttr
+                                                       dst     [4]byte
+                                                       srchdr  unix.RtAttr
+                                                       src     [4]byte
+                                                       markhdr unix.RtAttr
+                                                       mark    uint32
+                                               }{
+                                                       unix.NlMsghdr{
+                                                               Type:  uint16(unix.RTM_GETROUTE),
+                                                               Flags: unix.NLM_F_REQUEST,
+                                                               Seq:   i,
+                                                       },
+                                                       unix.RtMsg{
+                                                               Family:  unix.AF_INET,
+                                                               Dst_len: 32,
+                                                               Src_len: 32,
+                                                       },
+                                                       unix.RtAttr{
+                                                               Len:  8,
+                                                               Type: unix.RTA_DST,
+                                                       },
+                                                       peer.endpoint.(*NativeEndpoint).dst4().Addr,
+                                                       unix.RtAttr{
+                                                               Len:  8,
+                                                               Type: unix.RTA_SRC,
+                                                       },
+                                                       peer.endpoint.(*NativeEndpoint).src4().src,
+                                                       unix.RtAttr{
+                                                               Len:  8,
+                                                               Type: 0x10, //unix.RTA_MARK  TODO: add this to x/sys/unix
+                                                       },
+                                                       uint32(bind.lastMark),
+                                               }
+                                               nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+                                               reqPeer[i] = peer
+                                               peer.mutex.RUnlock()
+                                               i++
+                                               unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+                                       }
+                                       device.peers.mutex.RUnlock()
+                               }()
                        }
                        remain = remain[hdr.Len:]
                }