]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: store IP_PKTINFO cmsg in StdNetendpoint src
authorJames Tucker <james@tailscale.com>
Wed, 19 Apr 2023 05:29:55 +0000 (22:29 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 27 Jun 2023 15:48:32 +0000 (17:48 +0200)
Replace the src storage inside StdNetEndpoint with a copy of the raw
control message buffer, to reduce allocation and perform less work on a
per-packet basis.

Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/bind_std.go
conn/sticky_default.go
conn/sticky_linux.go
conn/sticky_linux_test.go

index 69789b33fe65b51d9652513c35d937b44235976a..c701ef8724007bda09a7f85060c0469e4473e672 100644 (file)
@@ -81,11 +81,10 @@ func NewStdNetBind() Bind {
 type StdNetEndpoint struct {
        // AddrPort is the endpoint destination.
        netip.AddrPort
-       // src is the current sticky source address and interface index, if supported.
-       src struct {
-               netip.Addr
-               ifidx int32
-       }
+       // src is the current sticky source address and interface index, if
+       // supported. Typically this is a PKTINFO structure from/for control
+       // messages, see unix.PKTINFO for an example.
+       src []byte
 }
 
 var (
@@ -104,21 +103,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
 }
 
 func (e *StdNetEndpoint) ClearSrc() {
-       e.src.ifidx = 0
-       e.src.Addr = netip.Addr{}
+       if e.src != nil {
+               // Truncate src, no need to reallocate.
+               e.src = e.src[:0]
+       }
 }
 
 func (e *StdNetEndpoint) DstIP() netip.Addr {
        return e.AddrPort.Addr()
 }
 
-func (e *StdNetEndpoint) SrcIP() netip.Addr {
-       return e.src.Addr
-}
-
-func (e *StdNetEndpoint) SrcIfidx() int32 {
-       return e.src.ifidx
-}
+// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
 
 func (e *StdNetEndpoint) DstToBytes() []byte {
        b, _ := e.AddrPort.MarshalBinary()
@@ -129,10 +124,6 @@ func (e *StdNetEndpoint) DstToString() string {
        return e.AddrPort.String()
 }
 
-func (e *StdNetEndpoint) SrcToString() string {
-       return e.src.Addr.String()
-}
-
 func listenNet(network string, port int) (*net.UDPConn, int, error) {
        conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
        if err != nil {
index 05f00ea5b08f6150948ade79b207077c811c0bb4..1fa8a0c4bb72dac43a47746045c2d33356c2b085 100644 (file)
@@ -7,6 +7,20 @@
 
 package conn
 
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+       return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+       return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+       return ""
+}
+
 // TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
 // use alternatively named flags and need ports and require testing.
 
index 274fa38a12b25fdd88ef996cb2d406c020cb4a23..a30ccc715c38da248ac40a53c942d61b1628b9ed 100644 (file)
@@ -14,6 +14,37 @@ import (
        "golang.org/x/sys/unix"
 )
 
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+       switch len(e.src) {
+       case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+               info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+               return netip.AddrFrom4(info.Spec_dst)
+       case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+               info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+               // TODO: set zone. in order to do so we need to check if the address is
+               // link local, and if it is perform a syscall to turn the ifindex into a
+               // zone string because netip uses string zones.
+               return netip.AddrFrom16(info.Addr)
+       }
+       return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+       switch len(e.src) {
+       case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+               info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+               return info.Ifindex
+       case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+               info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+               return int32(info.Ifindex)
+       }
+       return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+       return e.SrcIP().String()
+}
+
 // getSrcFromControl parses the control for PKTINFO and if found updates ep with
 // the source information found.
 func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
@@ -35,81 +66,43 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
                if hdr.Level == unix.IPPROTO_IP &&
                        hdr.Type == unix.IP_PKTINFO {
 
-                       info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
-                       ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
-                       ep.src.ifidx = info.Ifindex
+                       if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+                               ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+                       }
+                       ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
 
+                       hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+                       copy(ep.src, hdrBuf)
+                       copy(ep.src[unix.CmsgLen(0):], data)
                        return
                }
 
                if hdr.Level == unix.IPPROTO_IPV6 &&
                        hdr.Type == unix.IPV6_PKTINFO {
 
-                       info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
-                       ep.src.Addr = netip.AddrFrom16(info.Addr)
-                       ep.src.ifidx = int32(info.Ifindex)
+                       if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+                               ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+                       }
+
+                       ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
 
+                       hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+                       copy(ep.src, hdrBuf)
+                       copy(ep.src[unix.CmsgLen(0):], data)
                        return
                }
        }
 }
 
-// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
-// panics if buf is of insufficient size.
-func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
-       size := int(unsafe.Sizeof(t))
-       if len(buf) < size {
-               panic("pktInfoFromBuf: buffer too small")
-       }
-       copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
-       return t
-}
-
 // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
 // and source ifindex found in ep. control's len will be set to 0 in the event
 // that ep is a default value.
 func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
-       *control = (*control)[:cap(*control)]
-       if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
-               *control = (*control)[:0]
+       if cap(*control) < len(ep.src) {
                return
        }
-
-       if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
-               *control = (*control)[:0]
-               return
-       }
-
-       if len(*control) < srcControlSize {
-               *control = (*control)[:0]
-               return
-       }
-
-       hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
-       if ep.SrcIP().Is4() {
-               hdr.Level = unix.IPPROTO_IP
-               hdr.Type = unix.IP_PKTINFO
-               hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
-
-               info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
-               info.Ifindex = ep.src.ifidx
-               if ep.SrcIP().IsValid() {
-                       info.Spec_dst = ep.SrcIP().As4()
-               }
-               *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
-       } else {
-               hdr.Level = unix.IPPROTO_IPV6
-               hdr.Type = unix.IPV6_PKTINFO
-               hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
-
-               info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
-               info.Ifindex = uint32(ep.src.ifidx)
-               if ep.SrcIP().IsValid() {
-                       info.Addr = ep.SrcIP().As16()
-               }
-               *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
-       }
-
+       *control = (*control)[:0]
+       *control = append(*control, ep.src...)
 }
 
 var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
index 0219ac300811862709db5c8873c8a5f9c137ca7c..679213a82cdef0da8e4d02f5ac5b2e7dc1c16635 100644 (file)
@@ -18,13 +18,47 @@ import (
        "golang.org/x/sys/unix"
 )
 
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+       var buf []byte
+       if addr.Is4() {
+               buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+               hdr := unix.Cmsghdr{
+                       Level: unix.IPPROTO_IP,
+                       Type:  unix.IP_PKTINFO,
+               }
+               hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+               copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+               info := unix.Inet4Pktinfo{
+                       Ifindex:  ifidx,
+                       Spec_dst: addr.As4(),
+               }
+               copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+       } else {
+               buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+               hdr := unix.Cmsghdr{
+                       Level: unix.IPPROTO_IPV6,
+                       Type:  unix.IPV6_PKTINFO,
+               }
+               hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+               copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+               info := unix.Inet6Pktinfo{
+                       Ifindex: uint32(ifidx),
+                       Addr:    addr.As16(),
+               }
+               copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+       }
+
+       ep.src = buf
+}
+
 func Test_setSrcControl(t *testing.T) {
        t.Run("IPv4", func(t *testing.T) {
                ep := &StdNetEndpoint{
                        AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
                }
-               ep.src.Addr = netip.MustParseAddr("127.0.0.1")
-               ep.src.ifidx = 5
+               setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
 
                control := make([]byte, srcControlSize)
 
@@ -53,8 +87,7 @@ func Test_setSrcControl(t *testing.T) {
                ep := &StdNetEndpoint{
                        AddrPort: netip.MustParseAddrPort("[::1]:1234"),
                }
-               ep.src.Addr = netip.MustParseAddr("::1")
-               ep.src.ifidx = 5
+               setSrc(ep, netip.MustParseAddr("::1"), 5)
 
                control := make([]byte, srcControlSize)
 
@@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
        })
 
        t.Run("ClearOnNoSrc", func(t *testing.T) {
-               control := make([]byte, srcControlSize)
+               control := make([]byte, unix.CmsgLen(0))
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = 1
                hdr.Type = 2
@@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
 
 func Test_getSrcFromControl(t *testing.T) {
        t.Run("IPv4", func(t *testing.T) {
-               control := make([]byte, srcControlSize)
+               control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = unix.IPPROTO_IP
                hdr.Type = unix.IP_PKTINFO
@@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) {
                ep := &StdNetEndpoint{}
                getSrcFromControl(control, ep)
 
-               if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
-                       t.Errorf("unexpected address: %v", ep.src.Addr)
+               if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+                       t.Errorf("unexpected address: %v", ep.SrcIP())
                }
-               if ep.src.ifidx != 5 {
-                       t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+               if ep.SrcIfidx() != 5 {
+                       t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
                }
        })
        t.Run("IPv6", func(t *testing.T) {
-               control := make([]byte, srcControlSize)
+               control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = unix.IPPROTO_IPV6
                hdr.Type = unix.IPV6_PKTINFO
@@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) {
                if ep.SrcIP() != netip.MustParseAddr("::1") {
                        t.Errorf("unexpected address: %v", ep.SrcIP())
                }
-               if ep.src.ifidx != 5 {
-                       t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+               if ep.SrcIfidx() != 5 {
+                       t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
                }
        })
        t.Run("ClearOnEmpty", func(t *testing.T) {
-               control := make([]byte, srcControlSize)
+               var control []byte
                ep := &StdNetEndpoint{}
-               ep.src.Addr = netip.MustParseAddr("::1")
-               ep.src.ifidx = 5
+               setSrc(ep, netip.MustParseAddr("::1"), 5)
 
                getSrcFromControl(control, ep)
                if ep.SrcIP().IsValid() {
-                       t.Errorf("unexpected address: %v", ep.src.Addr)
+                       t.Errorf("unexpected address: %v", ep.SrcIP())
                }
-               if ep.src.ifidx != 0 {
-                       t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+               if ep.SrcIfidx() != 0 {
+                       t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
                }
        })
        t.Run("Multiple", func(t *testing.T) {
@@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) {
                zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
                zeroHdr.SetLen(unix.CmsgLen(0))
 
-               control := make([]byte, srcControlSize)
+               control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = unix.IPPROTO_IP
                hdr.Type = unix.IP_PKTINFO
@@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) {
                ep := &StdNetEndpoint{}
                getSrcFromControl(combined, ep)
 
-               if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
-                       t.Errorf("unexpected address: %v", ep.src.Addr)
+               if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+                       t.Errorf("unexpected address: %v", ep.SrcIP())
                }
-               if ep.src.ifidx != 5 {
-                       t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+               if ep.SrcIfidx() != 5 {
+                       t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
                }
        })
 }