]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Cancelable netlink writes and better response correlation
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 12:18:26 +0000 (14:18 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 12:18:26 +0000 (14:18 +0200)
conn_linux.go

index e30631f617cc3053c1f78bf41ec35dd554b7bea3..b0c6d962833bada582e601075aef9203a3bbd411 100644 (file)
@@ -545,7 +545,11 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 }
 
 func (bind *NativeBind) routineRouteListener(device *Device) {
-       var reqPeer map[uint32]*Peer
+       type peerEndpointPtr struct {
+               peer     *Peer
+               endpoint *Endpoint
+       }
+       var reqPeer map[uint32]peerEndpointPtr
 
        defer unix.Close(bind.netlinkSock)
 
@@ -594,34 +598,28 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
                                                                if reqPeer == nil {
                                                                        break
                                                                }
-                                                               peer, ok := reqPeer[hdr.Seq]
+                                                               pePtr, ok := reqPeer[hdr.Seq]
                                                                if !ok {
                                                                        break
                                                                }
-                                                               peer.mutex.RLock()
-                                                               if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
-                                                                       peer.mutex.RUnlock()
-                                                                       break
-                                                               }
-                                                               if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
-                                                                       peer.mutex.RUnlock()
+                                                               pePtr.peer.mutex.Lock()
+                                                               if &pePtr.peer.endpoint != pePtr.endpoint {
+                                                                       pePtr.peer.mutex.Unlock()
                                                                        break
                                                                }
-                                                               if uint32(peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
-                                                                       peer.mutex.RUnlock()
+                                                               if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
+                                                                       pePtr.peer.mutex.Unlock()
                                                                        break
                                                                }
-                                                               peer.mutex.RUnlock()
-                                                               peer.mutex.Lock()
-                                                               peer.endpoint.(*NativeEndpoint).ClearSrc()
-                                                               peer.mutex.Unlock()
+                                                               pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
+                                                               pePtr.peer.mutex.Unlock()
                                                        }
                                                        attr = attr[attrhdr.Len:]
                                                }
                                        }
                                        break
                                }
-                               reqPeer = make(map[uint32]*Peer)
+                               reqPeer = make(map[uint32]peerEndpointPtr)
                                go func() {
                                        device.peers.mutex.RLock()
                                        i := uint32(1)
@@ -672,10 +670,16 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
                                                        uint32(bind.lastMark),
                                                }
                                                nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
-                                               reqPeer[i] = peer
+                                               reqPeer[i] = peerEndpointPtr{
+                                                       peer:     peer,
+                                                       endpoint: &peer.endpoint,
+                                               }
                                                peer.mutex.RUnlock()
                                                i++
-                                               unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+                                               _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+                                               if err != nil {
+                                                       break
+                                               }
                                        }
                                        device.peers.mutex.RUnlock()
                                }()