]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: introduce new package that splits out the Bind and Endpoint types
authorDavid Crawshaw <crawshaw@tailscale.com>
Thu, 7 Nov 2019 16:13:05 +0000 (11:13 -0500)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sat, 2 May 2020 07:46:42 +0000 (01:46 -0600)
The sticky socket code stays in the device package for now,
as it reaches deeply into the peer list.

This is the first step in an effort to split some code out of
the very busy device package.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
15 files changed:
conn/boundif_windows.go [moved from device/boundif_windows.go with 66% similarity]
conn/conn.go [new file with mode: 0644]
conn/conn_default.go [moved from device/conn_default.go with 94% similarity]
conn/conn_linux.go [moved from device/conn_linux.go with 63% similarity]
conn/mark_default.go [moved from device/mark_default.go with 93% similarity]
conn/mark_unix.go [moved from device/mark_unix.go with 98% similarity]
device/bind_test.go
device/bindsocketshim.go [new file with mode: 0644]
device/conn.go [deleted file]
device/device.go
device/peer.go
device/receive.go
device/sticky_default.go [new file with mode: 0644]
device/sticky_linux.go [new file with mode: 0644]
device/uapi.go

similarity index 66%
rename from device/boundif_windows.go
rename to conn/boundif_windows.go
index 69084152815b6bcdd608da711dc2e498d8465720..fe38d05f5192675a5225f5149f3cb28f77b4811b 100644 (file)
@@ -3,11 +3,10 @@
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
  */
 
-package device
+package conn
 
 import (
        "encoding/binary"
-       "errors"
        "unsafe"
 
        "golang.org/x/sys/windows"
@@ -18,17 +17,13 @@ const (
        sockoptIPV6_UNICAST_IF = 31
 )
 
-func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
        /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
        bytes := make([]byte, 4)
        binary.BigEndian.PutUint32(bytes, interfaceIndex)
        interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
 
-       if device.net.bind == nil {
-               return errors.New("Bind is not yet initialized")
-       }
-
-       sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
+       sysconn, err := bind.ipv4.SyscallConn()
        if err != nil {
                return err
        }
@@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
        if err != nil {
                return err
        }
-       device.net.bind.(*nativeBind).blackhole4 = blackhole
+       bind.blackhole4 = blackhole
        return nil
 }
 
-func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
-       sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
+func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+       sysconn, err := bind.ipv6.SyscallConn()
        if err != nil {
                return err
        }
@@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
        if err != nil {
                return err
        }
-       device.net.bind.(*nativeBind).blackhole6 = blackhole
+       bind.blackhole6 = blackhole
        return nil
 }
diff --git a/conn/conn.go b/conn/conn.go
new file mode 100644 (file)
index 0000000..6b7db12
--- /dev/null
@@ -0,0 +1,101 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package conn implements WireGuard's network connections.
+package conn
+
+import (
+       "errors"
+       "net"
+       "strings"
+)
+
+// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
+type Bind interface {
+       // LastMark reports the last mark set for this Bind.
+       LastMark() uint32
+
+       // SetMark sets the mark for each packet sent through this Bind.
+       // This mark is passed to the kernel as the socket option SO_MARK.
+       SetMark(mark uint32) error
+
+       // ReceiveIPv6 reads an IPv6 UDP packet into b.
+       //
+       // It reports the number of bytes read, n,
+       // the packet source address ep,
+       // and any error.
+       ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
+
+       // ReceiveIPv4 reads an IPv4 UDP packet into b.
+       //
+       // It reports the number of bytes read, n,
+       // the packet source address ep,
+       // and any error.
+       ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
+
+       // Send writes a packet b to address ep.
+       Send(b []byte, ep Endpoint) error
+
+       // Close closes the Bind connection.
+       Close() error
+}
+
+// CreateBind creates a Bind bound to a port.
+//
+// The value actualPort reports the actual port number the Bind
+// object gets bound to.
+func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
+       return createBind(port)
+}
+
+// BindToInterface is implemented by Bind objects that support being
+// tied to a single network interface.
+type BindToInterface interface {
+       BindToInterface4(interfaceIndex uint32, blackhole bool) error
+       BindToInterface6(interfaceIndex uint32, blackhole bool) error
+}
+
+// An Endpoint maintains the source/destination caching for a peer.
+//
+//     dst : the remote address of a peer ("endpoint" in uapi terminology)
+//     src : the local address from which datagrams originate going to the peer
+type Endpoint interface {
+       ClearSrc()           // clears the source address
+       SrcToString() string // returns the local source address (ip:port)
+       DstToString() string // returns the destination address (ip:port)
+       DstToBytes() []byte  // used for mac2 cookie calculations
+       DstIP() net.IP
+       SrcIP() net.IP
+}
+
+func parseEndpoint(s string) (*net.UDPAddr, error) {
+       // ensure that the host is an IP address
+
+       host, _, err := net.SplitHostPort(s)
+       if err != nil {
+               return nil, err
+       }
+       if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
+               // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
+               // trying to make sure with a small sanity test that this is a real IP address and
+               // not something that's likely to incur DNS lookups.
+               host = host[:i]
+       }
+       if ip := net.ParseIP(host); ip == nil {
+               return nil, errors.New("Failed to parse IP address: " + host)
+       }
+
+       // parse address and port
+
+       addr, err := net.ResolveUDPAddr("udp", s)
+       if err != nil {
+               return nil, err
+       }
+       ip4 := addr.IP.To4()
+       if ip4 != nil {
+               addr.IP = ip4
+       }
+       return addr, err
+}
similarity index 94%
rename from device/conn_default.go
rename to conn/conn_default.go
index 661f57d97eea5084f62b32bfe8ad3dae7209ac7c..bad9d4df8792ec1600f407b0a5c99095eb8fbc73 100644 (file)
@@ -5,7 +5,7 @@
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
  */
 
-package device
+package conn
 
 import (
        "net"
@@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string {
 }
 
 func listenNet(network string, port int) (*net.UDPConn, int, error) {
-
-       // listen
-
        conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
        if err != nil {
                return nil, 0, err
        }
 
-       // retrieve port
-
+       // Retrieve port.
+       // TODO(crawshaw): under what circumstances is this necessary?
        laddr := conn.LocalAddr()
        uaddr, err := net.ResolveUDPAddr(
                laddr.Network(),
@@ -100,7 +97,7 @@ func extractErrno(err error) error {
        return syscallErr.Err
 }
 
-func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
+func createBind(uport uint16) (Bind, uint16, error) {
        var err error
        var bind nativeBind
 
@@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error {
        return err2
 }
 
+func (bind *nativeBind) LastMark() uint32 { return 0 }
+
 func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        if bind.ipv4 == nil {
                return 0, nil, syscall.EAFNOSUPPORT
similarity index 63%
rename from device/conn_linux.go
rename to conn/conn_linux.go
index e90b0e35b456c1a3cd752fa10e0c441dbca65c68..523da4a4555dd167046cabde671739fd4438e812 100644 (file)
@@ -3,18 +3,9 @@
 /* SPDX-License-Identifier: MIT
  *
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- *
- * This implements userspace semantics of "sticky sockets", modeled after
- * WireGuard's kernelspace implementation. This is more or less a straight port
- * of the sticky-sockets.c example code:
- * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
- *
- * Currently there is no way to achieve this within the net package:
- * See e.g. https://github.com/golang/go/issues/17930
- * So this code is remains platform dependent.
  */
 
-package device
+package conn
 
 import (
        "errors"
@@ -25,7 +16,6 @@ import (
        "unsafe"
 
        "golang.org/x/sys/unix"
-       "golang.zx2c4.com/wireguard/rwcancel"
 )
 
 const (
@@ -33,8 +23,8 @@ const (
 )
 
 type IPv4Source struct {
-       src     [4]byte
-       ifindex int32
+       Src     [4]byte
+       Ifindex int32
 }
 
 type IPv6Source struct {
@@ -49,6 +39,10 @@ type NativeEndpoint struct {
        isV6 bool
 }
 
+func (endpoint *NativeEndpoint) Src4() *IPv4Source         { return endpoint.src4() }
+func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
+func (endpoint *NativeEndpoint) IsV6() bool                { return endpoint.isV6 }
+
 func (endpoint *NativeEndpoint) src4() *IPv4Source {
        return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
 }
@@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
 }
 
 type nativeBind struct {
-       sock4         int
-       sock6         int
-       netlinkSock   int
-       netlinkCancel *rwcancel.RWCancel
-       lastMark      uint32
+       sock4    int
+       sock6    int
+       lastMark uint32
 }
 
 var _ Endpoint = (*NativeEndpoint)(nil)
@@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
        return nil, errors.New("Invalid IP address")
 }
 
-func createNetlinkRouteSocket() (int, error) {
-       sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
-       if err != nil {
-               return -1, err
-       }
-       saddr := &unix.SockaddrNetlink{
-               Family: unix.AF_NETLINK,
-               Groups: unix.RTMGRP_IPV4_ROUTE,
-       }
-       err = unix.Bind(sock, saddr)
-       if err != nil {
-               unix.Close(sock)
-               return -1, err
-       }
-       return sock, nil
-
-}
-
-func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
+func createBind(port uint16) (Bind, uint16, error) {
        var err error
        var bind nativeBind
        var newPort uint16
 
-       bind.netlinkSock, err = createNetlinkRouteSocket()
-       if err != nil {
-               return nil, 0, err
-       }
-       bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
-       if err != nil {
-               unix.Close(bind.netlinkSock)
-               return nil, 0, err
-       }
-
-       go bind.routineRouteListener(device)
-
-       // attempt ipv6 bind, update port if successful
-
+       // Attempt ipv6 bind, update port if successful.
        bind.sock6, newPort, err = create6(port)
        if err != nil {
                if err != syscall.EAFNOSUPPORT {
-                       bind.netlinkCancel.Cancel()
                        return nil, 0, err
                }
        } else {
                port = newPort
        }
 
-       // attempt ipv4 bind, update port if successful
-
+       // Attempt ipv4 bind, update port if successful.
        bind.sock4, newPort, err = create4(port)
        if err != nil {
                if err != syscall.EAFNOSUPPORT {
-                       bind.netlinkCancel.Cancel()
                        unix.Close(bind.sock6)
                        return nil, 0, err
                }
@@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
        return &bind, port, nil
 }
 
+func (bind *nativeBind) LastMark() uint32 {
+       return bind.lastMark
+}
+
 func (bind *nativeBind) SetMark(value uint32) error {
        if bind.sock6 != -1 {
                err := unix.SetsockoptInt(
@@ -216,22 +178,18 @@ func closeUnblock(fd int) error {
 }
 
 func (bind *nativeBind) Close() error {
-       var err1, err2, err3 error
+       var err1, err2 error
        if bind.sock6 != -1 {
                err1 = closeUnblock(bind.sock6)
        }
        if bind.sock4 != -1 {
                err2 = closeUnblock(bind.sock4)
        }
-       err3 = bind.netlinkCancel.Cancel()
 
        if err1 != nil {
                return err1
        }
-       if err2 != nil {
-               return err2
-       }
-       return err3
+       return err2
 }
 
 func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
@@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
 func (end *NativeEndpoint) SrcIP() net.IP {
        if !end.isV6 {
                return net.IPv4(
-                       end.src4().src[0],
-                       end.src4().src[1],
-                       end.src4().src[2],
-                       end.src4().src[3],
+                       end.src4().Src[0],
+                       end.src4().Src[1],
+                       end.src4().Src[2],
+                       end.src4().Src[3],
                )
        } else {
                return end.src6().src[:]
@@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
                        Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
                },
                unix.Inet4Pktinfo{
-                       Spec_dst: end.src4().src,
-                       Ifindex:  end.src4().ifindex,
+                       Spec_dst: end.src4().Src,
+                       Ifindex:  end.src4().Ifindex,
                },
        }
 
@@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
        if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
                cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
                cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-               end.src4().src = cmsg.pktinfo.Spec_dst
-               end.src4().ifindex = cmsg.pktinfo.Ifindex
+               end.src4().Src = cmsg.pktinfo.Spec_dst
+               end.src4().Ifindex = cmsg.pktinfo.Ifindex
        }
 
        return size, nil
@@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 
        return size, nil
 }
-
-func (bind *nativeBind) routineRouteListener(device *Device) {
-       type peerEndpointPtr struct {
-               peer     *Peer
-               endpoint *Endpoint
-       }
-       var reqPeer map[uint32]peerEndpointPtr
-       var reqPeerLock sync.Mutex
-
-       defer unix.Close(bind.netlinkSock)
-
-       for msg := make([]byte, 1<<16); ; {
-               var err error
-               var msgn int
-               for {
-                       msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
-                       if err == nil || !rwcancel.RetryAfterError(err) {
-                               break
-                       }
-                       if !bind.netlinkCancel.ReadyRead() {
-                               return
-                       }
-               }
-               if err != nil {
-                       return
-               }
-
-               for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
-
-                       hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
-
-                       if uint(hdr.Len) > uint(len(remain)) {
-                               break
-                       }
-
-                       switch hdr.Type {
-                       case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
-                               if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
-                                       if uint(len(remain)) < uint(hdr.Len) {
-                                               break
-                                       }
-                                       if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
-                                               attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
-                                               for {
-                                                       if uint(len(attr)) < uint(unix.SizeofRtAttr) {
-                                                               break
-                                                       }
-                                                       attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
-                                                       if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
-                                                               break
-                                                       }
-                                                       if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
-                                                               ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
-                                                               reqPeerLock.Lock()
-                                                               if reqPeer == nil {
-                                                                       reqPeerLock.Unlock()
-                                                                       break
-                                                               }
-                                                               pePtr, ok := reqPeer[hdr.Seq]
-                                                               reqPeerLock.Unlock()
-                                                               if !ok {
-                                                                       break
-                                                               }
-                                                               pePtr.peer.Lock()
-                                                               if &pePtr.peer.endpoint != pePtr.endpoint {
-                                                                       pePtr.peer.Unlock()
-                                                                       break
-                                                               }
-                                                               if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
-                                                                       pePtr.peer.Unlock()
-                                                                       break
-                                                               }
-                                                               pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
-                                                               pePtr.peer.Unlock()
-                                                       }
-                                                       attr = attr[attrhdr.Len:]
-                                               }
-                                       }
-                                       break
-                               }
-                               reqPeerLock.Lock()
-                               reqPeer = make(map[uint32]peerEndpointPtr)
-                               reqPeerLock.Unlock()
-                               go func() {
-                                       device.peers.RLock()
-                                       i := uint32(1)
-                                       for _, peer := range device.peers.keyMap {
-                                               peer.RLock()
-                                               if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
-                                                       peer.RUnlock()
-                                                       continue
-                                               }
-                                               if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
-                                                       peer.RUnlock()
-                                                       break
-                                               }
-                                               nlmsg := struct {
-                                                       hdr     unix.NlMsghdr
-                                                       msg     unix.RtMsg
-                                                       dsthdr  unix.RtAttr
-                                                       dst     [4]byte
-                                                       srchdr  unix.RtAttr
-                                                       src     [4]byte
-                                                       markhdr unix.RtAttr
-                                                       mark    uint32
-                                               }{
-                                                       unix.NlMsghdr{
-                                                               Type:  uint16(unix.RTM_GETROUTE),
-                                                               Flags: unix.NLM_F_REQUEST,
-                                                               Seq:   i,
-                                                       },
-                                                       unix.RtMsg{
-                                                               Family:  unix.AF_INET,
-                                                               Dst_len: 32,
-                                                               Src_len: 32,
-                                                       },
-                                                       unix.RtAttr{
-                                                               Len:  8,
-                                                               Type: unix.RTA_DST,
-                                                       },
-                                                       peer.endpoint.(*NativeEndpoint).dst4().Addr,
-                                                       unix.RtAttr{
-                                                               Len:  8,
-                                                               Type: unix.RTA_SRC,
-                                                       },
-                                                       peer.endpoint.(*NativeEndpoint).src4().src,
-                                                       unix.RtAttr{
-                                                               Len:  8,
-                                                               Type: unix.RTA_MARK,
-                                                       },
-                                                       uint32(bind.lastMark),
-                                               }
-                                               nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
-                                               reqPeerLock.Lock()
-                                               reqPeer[i] = peerEndpointPtr{
-                                                       peer:     peer,
-                                                       endpoint: &peer.endpoint,
-                                               }
-                                               reqPeerLock.Unlock()
-                                               peer.RUnlock()
-                                               i++
-                                               _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
-                                               if err != nil {
-                                                       break
-                                               }
-                                       }
-                                       device.peers.RUnlock()
-                               }()
-                       }
-                       remain = remain[hdr.Len:]
-               }
-       }
-}
similarity index 93%
rename from device/mark_default.go
rename to conn/mark_default.go
index 7de2524c01b5124e90b776a953b7a48f139ffd5a..fc41ba9931bd384c5b0d00d0ceff76699ca0392a 100644 (file)
@@ -5,7 +5,7 @@
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
  */
 
-package device
+package conn
 
 func (bind *nativeBind) SetMark(mark uint32) error {
        return nil
similarity index 98%
rename from device/mark_unix.go
rename to conn/mark_unix.go
index 669b3281464e6004076b6de9905fa362111cf412..5334582e93ea6feb5b9501fd04353f6b9eae9014 100644 (file)
@@ -5,7 +5,7 @@
  * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
  */
 
-package device
+package conn
 
 import (
        "runtime"
index 0c2e2cfd282433f9e29303aabb85513ae66b1cc2..c5f7f68a9772bcd086fbd4a751798bd797c6040e 100644 (file)
@@ -5,11 +5,15 @@
 
 package device
 
-import "errors"
+import (
+       "errors"
+
+       "golang.zx2c4.com/wireguard/conn"
+)
 
 type DummyDatagram struct {
        msg      []byte
-       endpoint Endpoint
+       endpoint conn.Endpoint
        world    bool // better type
 }
 
@@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
        return nil
 }
 
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
        datagram, ok := <-b.in6
        if !ok {
                return 0, nil, errors.New("closed")
@@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        return len(datagram.msg), datagram.endpoint, nil
 }
 
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
        datagram, ok := <-b.in4
        if !ok {
                return 0, nil, errors.New("closed")
@@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
        return nil
 }
 
-func (b *DummyBind) Send(buff []byte, end Endpoint) error {
+func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
        return nil
 }
diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go
new file mode 100644 (file)
index 0000000..c4dd4ef
--- /dev/null
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+       "errors"
+
+       "golang.zx2c4.com/wireguard/conn"
+)
+
+// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
+func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+       if device.net.bind == nil {
+               return errors.New("Bind is not yet initialized")
+       }
+
+       if iface, ok := device.net.bind.(conn.BindToInterface); ok {
+               return iface.BindToInterface4(interfaceIndex, blackhole)
+       }
+       return nil
+}
+
+// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
+func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+       if device.net.bind == nil {
+               return errors.New("Bind is not yet initialized")
+       }
+
+       if iface, ok := device.net.bind.(conn.BindToInterface); ok {
+               return iface.BindToInterface6(interfaceIndex, blackhole)
+       }
+       return nil
+}
diff --git a/device/conn.go b/device/conn.go
deleted file mode 100644 (file)
index 7b341f6..0000000
+++ /dev/null
@@ -1,187 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
-       "errors"
-       "net"
-       "strings"
-
-       "golang.org/x/net/ipv4"
-       "golang.org/x/net/ipv6"
-)
-
-const (
-       ConnRoutineNumber = 2
-)
-
-/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
- */
-type Bind interface {
-       SetMark(value uint32) error
-       ReceiveIPv6(buff []byte) (int, Endpoint, error)
-       ReceiveIPv4(buff []byte) (int, Endpoint, error)
-       Send(buff []byte, end Endpoint) error
-       Close() error
-}
-
-/* An Endpoint maintains the source/destination caching for a peer
- *
- * dst : the remote address of a peer ("endpoint" in uapi terminology)
- * src : the local address from which datagrams originate going to the peer
- */
-type Endpoint interface {
-       ClearSrc()           // clears the source address
-       SrcToString() string // returns the local source address (ip:port)
-       DstToString() string // returns the destination address (ip:port)
-       DstToBytes() []byte  // used for mac2 cookie calculations
-       DstIP() net.IP
-       SrcIP() net.IP
-}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
-       // ensure that the host is an IP address
-
-       host, _, err := net.SplitHostPort(s)
-       if err != nil {
-               return nil, err
-       }
-       if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
-               // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
-               // trying to make sure with a small sanity test that this is a real IP address and
-               // not something that's likely to incur DNS lookups.
-               host = host[:i]
-       }
-       if ip := net.ParseIP(host); ip == nil {
-               return nil, errors.New("Failed to parse IP address: " + host)
-       }
-
-       // parse address and port
-
-       addr, err := net.ResolveUDPAddr("udp", s)
-       if err != nil {
-               return nil, err
-       }
-       ip4 := addr.IP.To4()
-       if ip4 != nil {
-               addr.IP = ip4
-       }
-       return addr, err
-}
-
-func unsafeCloseBind(device *Device) error {
-       var err error
-       netc := &device.net
-       if netc.bind != nil {
-               err = netc.bind.Close()
-               netc.bind = nil
-       }
-       netc.stopping.Wait()
-       return err
-}
-
-func (device *Device) BindSetMark(mark uint32) error {
-
-       device.net.Lock()
-       defer device.net.Unlock()
-
-       // check if modified
-
-       if device.net.fwmark == mark {
-               return nil
-       }
-
-       // update fwmark on existing bind
-
-       device.net.fwmark = mark
-       if device.isUp.Get() && device.net.bind != nil {
-               if err := device.net.bind.SetMark(mark); err != nil {
-                       return err
-               }
-       }
-
-       // clear cached source addresses
-
-       device.peers.RLock()
-       for _, peer := range device.peers.keyMap {
-               peer.Lock()
-               defer peer.Unlock()
-               if peer.endpoint != nil {
-                       peer.endpoint.ClearSrc()
-               }
-       }
-       device.peers.RUnlock()
-
-       return nil
-}
-
-func (device *Device) BindUpdate() error {
-
-       device.net.Lock()
-       defer device.net.Unlock()
-
-       // close existing sockets
-
-       if err := unsafeCloseBind(device); err != nil {
-               return err
-       }
-
-       // open new sockets
-
-       if device.isUp.Get() {
-
-               // bind to new port
-
-               var err error
-               netc := &device.net
-               netc.bind, netc.port, err = CreateBind(netc.port, device)
-               if err != nil {
-                       netc.bind = nil
-                       netc.port = 0
-                       return err
-               }
-
-               // set fwmark
-
-               if netc.fwmark != 0 {
-                       err = netc.bind.SetMark(netc.fwmark)
-                       if err != nil {
-                               return err
-                       }
-               }
-
-               // clear cached source addresses
-
-               device.peers.RLock()
-               for _, peer := range device.peers.keyMap {
-                       peer.Lock()
-                       defer peer.Unlock()
-                       if peer.endpoint != nil {
-                               peer.endpoint.ClearSrc()
-                       }
-               }
-               device.peers.RUnlock()
-
-               // start receiving routines
-
-               device.net.starting.Add(ConnRoutineNumber)
-               device.net.stopping.Add(ConnRoutineNumber)
-               go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
-               go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
-               device.net.starting.Wait()
-
-               device.log.Debug.Println("UDP bind has been updated")
-       }
-
-       return nil
-}
-
-func (device *Device) BindClose() error {
-       device.net.Lock()
-       err := unsafeCloseBind(device)
-       device.net.Unlock()
-       return err
-}
index 8c08f1c34f8c679326c7cb3f5a888e70d6788f0f..a9fedea86b3481bf9f581e6f850b467536de1efc 100644 (file)
@@ -11,15 +11,14 @@ import (
        "sync/atomic"
        "time"
 
+       "golang.org/x/net/ipv4"
+       "golang.org/x/net/ipv6"
+       "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/ratelimiter"
+       "golang.zx2c4.com/wireguard/rwcancel"
        "golang.zx2c4.com/wireguard/tun"
 )
 
-const (
-       DeviceRoutineNumberPerCPU     = 3
-       DeviceRoutineNumberAdditional = 2
-)
-
 type Device struct {
        isUp     AtomicBool // device is (going) up
        isClosed AtomicBool // device is closed? (acting as guard)
@@ -39,9 +38,10 @@ type Device struct {
                starting sync.WaitGroup
                stopping sync.WaitGroup
                sync.RWMutex
-               bind   Bind   // bind interface
-               port   uint16 // listening port
-               fwmark uint32 // mark value (0 = disabled)
+               bind          conn.Bind // bind interface
+               netlinkCancel *rwcancel.RWCancel
+               port          uint16 // listening port
+               fwmark        uint32 // mark value (0 = disabled)
        }
 
        staticIdentity struct {
@@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        cpus := runtime.NumCPU()
        device.state.starting.Wait()
        device.state.stopping.Wait()
-       device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
-       device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
        for i := 0; i < cpus; i += 1 {
+               device.state.starting.Add(3)
+               device.state.stopping.Add(3)
                go device.RoutineEncryption()
                go device.RoutineDecryption()
                go device.RoutineHandshake()
        }
 
+       device.state.starting.Add(2)
+       device.state.stopping.Add(2)
        go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
 
@@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
        }
        device.peers.RUnlock()
 }
+
+func unsafeCloseBind(device *Device) error {
+       var err error
+       netc := &device.net
+       if netc.netlinkCancel != nil {
+               netc.netlinkCancel.Cancel()
+       }
+       if netc.bind != nil {
+               err = netc.bind.Close()
+               netc.bind = nil
+       }
+       netc.stopping.Wait()
+       return err
+}
+
+func (device *Device) BindSetMark(mark uint32) error {
+
+       device.net.Lock()
+       defer device.net.Unlock()
+
+       // check if modified
+
+       if device.net.fwmark == mark {
+               return nil
+       }
+
+       // update fwmark on existing bind
+
+       device.net.fwmark = mark
+       if device.isUp.Get() && device.net.bind != nil {
+               if err := device.net.bind.SetMark(mark); err != nil {
+                       return err
+               }
+       }
+
+       // clear cached source addresses
+
+       device.peers.RLock()
+       for _, peer := range device.peers.keyMap {
+               peer.Lock()
+               defer peer.Unlock()
+               if peer.endpoint != nil {
+                       peer.endpoint.ClearSrc()
+               }
+       }
+       device.peers.RUnlock()
+
+       return nil
+}
+
+func (device *Device) BindUpdate() error {
+
+       device.net.Lock()
+       defer device.net.Unlock()
+
+       // close existing sockets
+
+       if err := unsafeCloseBind(device); err != nil {
+               return err
+       }
+
+       // open new sockets
+
+       if device.isUp.Get() {
+
+               // bind to new port
+
+               var err error
+               netc := &device.net
+               netc.bind, netc.port, err = conn.CreateBind(netc.port)
+               if err != nil {
+                       netc.bind = nil
+                       netc.port = 0
+                       return err
+               }
+               netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+               if err != nil {
+                       netc.bind.Close()
+                       netc.bind = nil
+                       netc.port = 0
+                       return err
+               }
+
+               // set fwmark
+
+               if netc.fwmark != 0 {
+                       err = netc.bind.SetMark(netc.fwmark)
+                       if err != nil {
+                               return err
+                       }
+               }
+
+               // clear cached source addresses
+
+               device.peers.RLock()
+               for _, peer := range device.peers.keyMap {
+                       peer.Lock()
+                       defer peer.Unlock()
+                       if peer.endpoint != nil {
+                               peer.endpoint.ClearSrc()
+                       }
+               }
+               device.peers.RUnlock()
+
+               // start receiving routines
+
+               device.net.starting.Add(2)
+               device.net.stopping.Add(2)
+               go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
+               go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+               device.net.starting.Wait()
+
+               device.log.Debug.Println("UDP bind has been updated")
+       }
+
+       return nil
+}
+
+func (device *Device) BindClose() error {
+       device.net.Lock()
+       err := unsafeCloseBind(device)
+       device.net.Unlock()
+       return err
+}
index 19434cda3dc11059de71dd766413a6916e4ae007..79d4981812bfc2ea82a8a6fc2752e090ad613838 100644 (file)
@@ -12,6 +12,8 @@ import (
        "sync"
        "sync/atomic"
        "time"
+
+       "golang.zx2c4.com/wireguard/conn"
 )
 
 const (
@@ -24,7 +26,7 @@ type Peer struct {
        keypairs                    Keypairs
        handshake                   Handshake
        device                      *Device
-       endpoint                    Endpoint
+       endpoint                    conn.Endpoint
        persistentKeepaliveInterval uint16
 
        // These fields are accessed with atomic operations, which must be
@@ -290,7 +292,7 @@ func (peer *Peer) Stop() {
 
 var RoamingDisabled bool
 
-func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
+func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
        if RoamingDisabled {
                return
        }
index 7d0693e1f9e3206694333a61ca65fd3f5d07896b..4818d649ed983f54caeff4e859bbf7281d3e53aa 100644 (file)
@@ -17,12 +17,13 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/net/ipv4"
        "golang.org/x/net/ipv6"
+       "golang.zx2c4.com/wireguard/conn"
 )
 
 type QueueHandshakeElement struct {
        msgType  uint32
        packet   []byte
-       endpoint Endpoint
+       endpoint conn.Endpoint
        buffer   *[MaxMessageSize]byte
 }
 
@@ -33,7 +34,7 @@ type QueueInboundElement struct {
        packet   []byte
        counter  uint64
        keypair  *Keypair
-       endpoint Endpoint
+       endpoint conn.Endpoint
 }
 
 func (elem *QueueInboundElement) Drop() {
@@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
  * Every time the bind is updated a new routine is started for
  * IPv4 and IPv6 (separately)
  */
-func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
+func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
 
        logDebug := device.log.Debug
        defer func() {
@@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
        var (
                err      error
                size     int
-               endpoint Endpoint
+               endpoint conn.Endpoint
        )
 
        for {
diff --git a/device/sticky_default.go b/device/sticky_default.go
new file mode 100644 (file)
index 0000000..1cc52f6
--- /dev/null
@@ -0,0 +1,12 @@
+// +build !linux
+
+package device
+
+import (
+       "golang.zx2c4.com/wireguard/conn"
+       "golang.zx2c4.com/wireguard/rwcancel"
+)
+
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+       return nil, nil
+}
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
new file mode 100644 (file)
index 0000000..f9522c2
--- /dev/null
@@ -0,0 +1,215 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ *
+ * This implements userspace semantics of "sticky sockets", modeled after
+ * WireGuard's kernelspace implementation. This is more or less a straight port
+ * of the sticky-sockets.c example code:
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
+ */
+
+package device
+
+import (
+       "sync"
+       "unsafe"
+
+       "golang.org/x/sys/unix"
+       "golang.zx2c4.com/wireguard/conn"
+       "golang.zx2c4.com/wireguard/rwcancel"
+)
+
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+       netlinkSock, err := createNetlinkRouteSocket()
+       if err != nil {
+               return nil, err
+       }
+       netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
+       if err != nil {
+               unix.Close(netlinkSock)
+               return nil, err
+       }
+
+       go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
+
+       return netlinkCancel, nil
+}
+
+func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
+       type peerEndpointPtr struct {
+               peer     *Peer
+               endpoint *conn.Endpoint
+       }
+       var reqPeer map[uint32]peerEndpointPtr
+       var reqPeerLock sync.Mutex
+
+       defer unix.Close(netlinkSock)
+
+       for msg := make([]byte, 1<<16); ; {
+               var err error
+               var msgn int
+               for {
+                       msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
+                       if err == nil || !rwcancel.RetryAfterError(err) {
+                               break
+                       }
+                       if !netlinkCancel.ReadyRead() {
+                               return
+                       }
+               }
+               if err != nil {
+                       return
+               }
+
+               for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+                       hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+                       if uint(hdr.Len) > uint(len(remain)) {
+                               break
+                       }
+
+                       switch hdr.Type {
+                       case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+                               if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
+                                       if uint(len(remain)) < uint(hdr.Len) {
+                                               break
+                                       }
+                                       if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
+                                               attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
+                                               for {
+                                                       if uint(len(attr)) < uint(unix.SizeofRtAttr) {
+                                                               break
+                                                       }
+                                                       attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
+                                                       if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
+                                                               break
+                                                       }
+                                                       if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
+                                                               ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
+                                                               reqPeerLock.Lock()
+                                                               if reqPeer == nil {
+                                                                       reqPeerLock.Unlock()
+                                                                       break
+                                                               }
+                                                               pePtr, ok := reqPeer[hdr.Seq]
+                                                               reqPeerLock.Unlock()
+                                                               if !ok {
+                                                                       break
+                                                               }
+                                                               pePtr.peer.Lock()
+                                                               if &pePtr.peer.endpoint != pePtr.endpoint {
+                                                                       pePtr.peer.Unlock()
+                                                                       break
+                                                               }
+                                                               if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
+                                                                       pePtr.peer.Unlock()
+                                                                       break
+                                                               }
+                                                               pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
+                                                               pePtr.peer.Unlock()
+                                                       }
+                                                       attr = attr[attrhdr.Len:]
+                                               }
+                                       }
+                                       break
+                               }
+                               reqPeerLock.Lock()
+                               reqPeer = make(map[uint32]peerEndpointPtr)
+                               reqPeerLock.Unlock()
+                               go func() {
+                                       device.peers.RLock()
+                                       i := uint32(1)
+                                       for _, peer := range device.peers.keyMap {
+                                               peer.RLock()
+                                               if peer.endpoint == nil {
+                                                       peer.RUnlock()
+                                                       continue
+                                               }
+                                               nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
+                                               if nativeEP == nil {
+                                                       peer.RUnlock()
+                                                       continue
+                                               }
+                                               if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
+                                                       peer.RUnlock()
+                                                       break
+                                               }
+                                               nlmsg := struct {
+                                                       hdr     unix.NlMsghdr
+                                                       msg     unix.RtMsg
+                                                       dsthdr  unix.RtAttr
+                                                       dst     [4]byte
+                                                       srchdr  unix.RtAttr
+                                                       src     [4]byte
+                                                       markhdr unix.RtAttr
+                                                       mark    uint32
+                                               }{
+                                                       unix.NlMsghdr{
+                                                               Type:  uint16(unix.RTM_GETROUTE),
+                                                               Flags: unix.NLM_F_REQUEST,
+                                                               Seq:   i,
+                                                       },
+                                                       unix.RtMsg{
+                                                               Family:  unix.AF_INET,
+                                                               Dst_len: 32,
+                                                               Src_len: 32,
+                                                       },
+                                                       unix.RtAttr{
+                                                               Len:  8,
+                                                               Type: unix.RTA_DST,
+                                                       },
+                                                       nativeEP.Dst4().Addr,
+                                                       unix.RtAttr{
+                                                               Len:  8,
+                                                               Type: unix.RTA_SRC,
+                                                       },
+                                                       nativeEP.Src4().Src,
+                                                       unix.RtAttr{
+                                                               Len:  8,
+                                                               Type: unix.RTA_MARK,
+                                                       },
+                                                       uint32(bind.LastMark()),
+                                               }
+                                               nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+                                               reqPeerLock.Lock()
+                                               reqPeer[i] = peerEndpointPtr{
+                                                       peer:     peer,
+                                                       endpoint: &peer.endpoint,
+                                               }
+                                               reqPeerLock.Unlock()
+                                               peer.RUnlock()
+                                               i++
+                                               _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+                                               if err != nil {
+                                                       break
+                                               }
+                                       }
+                                       device.peers.RUnlock()
+                               }()
+                       }
+                       remain = remain[hdr.Len:]
+               }
+       }
+}
+
+func createNetlinkRouteSocket() (int, error) {
+       sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+       if err != nil {
+               return -1, err
+       }
+       saddr := &unix.SockaddrNetlink{
+               Family: unix.AF_NETLINK,
+               Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+       }
+       err = unix.Bind(sock, saddr)
+       if err != nil {
+               unix.Close(sock)
+               return -1, err
+       }
+       return sock, nil
+}
index 72611ab5e93cdf2432a3bbfc3d9dcbed7dba0840..6cdccd61589086b868ae12f434d2f596a39789d5 100644 (file)
@@ -15,6 +15,7 @@ import (
        "sync/atomic"
        "time"
 
+       "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/ipc"
 )
 
@@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
                                err := func() error {
                                        peer.Lock()
                                        defer peer.Unlock()
-                                       endpoint, err := CreateEndpoint(value)
+                                       endpoint, err := conn.CreateEndpoint(value)
                                        if err != nil {
                                                return err
                                        }