]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn, device: use UDP GSO and GRO on Linux
authorJordan Whited <jordan@tailscale.com>
Mon, 2 Oct 2023 20:53:07 +0000 (13:53 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Oct 2023 13:07:36 +0000 (15:07 +0200)
StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is
dependent on checksum offload support on the egress netdev. UDP GSO
will be disabled in the event sendmmsg() returns EIO, which is a strong
signal that the egress netdev does not support checksum offload.

The iperf3 results below demonstrate the effect of this commit between
two Linux computers with i5-12400 CPUs. There is roughly ~13us of round
trip latency between them.

The first result is from commit 052af4a without UDP GSO or GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139   3.01 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  9.85 GBytes  8.46 Gbits/sec  1139  sender
[  5]   0.00-10.04  sec  9.85 GBytes  8.42 Gbits/sec        receiver

The second result is with UDP GSO and GRO.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
13 files changed:
conn/bind_std.go
conn/bind_std_test.go
conn/control_default.go [moved from conn/sticky_default.go with 56% similarity]
conn/control_linux.go [moved from conn/sticky_linux.go with 65% similarity]
conn/control_linux_test.go [moved from conn/sticky_linux_test.go with 96% similarity]
conn/controlfns_linux.go
conn/errors_default.go [new file with mode: 0644]
conn/errors_linux.go [new file with mode: 0644]
conn/features_default.go [new file with mode: 0644]
conn/features_linux.go [new file with mode: 0644]
device/send.go
go.mod
go.sum

index c701ef8724007bda09a7f85060c0469e4473e672..9886c91f99a38358a554182a0f5ee38ee2ab4015 100644 (file)
@@ -8,6 +8,7 @@ package conn
 import (
        "context"
        "errors"
+       "fmt"
        "net"
        "net/netip"
        "runtime"
@@ -29,16 +30,19 @@ var (
 // methods for sending and receiving multiple datagrams per-syscall. See the
 // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
 type StdNetBind struct {
-       mu     sync.Mutex // protects all fields except as specified
-       ipv4   *net.UDPConn
-       ipv6   *net.UDPConn
-       ipv4PC *ipv4.PacketConn // will be nil on non-Linux
-       ipv6PC *ipv6.PacketConn // will be nil on non-Linux
-
-       // these three fields are not guarded by mu
-       udpAddrPool  sync.Pool
-       ipv4MsgsPool sync.Pool
-       ipv6MsgsPool sync.Pool
+       mu            sync.Mutex // protects all fields except as specified
+       ipv4          *net.UDPConn
+       ipv6          *net.UDPConn
+       ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
+       ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
+       ipv4TxOffload bool
+       ipv4RxOffload bool
+       ipv6TxOffload bool
+       ipv6RxOffload bool
+
+       // these two fields are not guarded by mu
+       udpAddrPool sync.Pool
+       msgsPool    sync.Pool
 
        blackhole4 bool
        blackhole6 bool
@@ -54,23 +58,14 @@ func NewStdNetBind() Bind {
                        },
                },
 
-               ipv4MsgsPool: sync.Pool{
-                       New: func() any {
-                               msgs := make([]ipv4.Message, IdealBatchSize)
-                               for i := range msgs {
-                                       msgs[i].Buffers = make(net.Buffers, 1)
-                                       msgs[i].OOB = make([]byte, srcControlSize)
-                               }
-                               return &msgs
-                       },
-               },
-
-               ipv6MsgsPool: sync.Pool{
+               msgsPool: sync.Pool{
                        New: func() any {
+                               // ipv6.Message and ipv4.Message are interchangeable as they are
+                               // both aliases for x/net/internal/socket.Message.
                                msgs := make([]ipv6.Message, IdealBatchSize)
                                for i := range msgs {
                                        msgs[i].Buffers = make(net.Buffers, 1)
-                                       msgs[i].OOB = make([]byte, srcControlSize)
+                                       msgs[i].OOB = make([]byte, controlSize)
                                }
                                return &msgs
                        },
@@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr {
        return e.AddrPort.Addr()
 }
 
-// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
 
 func (e *StdNetEndpoint) DstToBytes() []byte {
        b, _ := e.AddrPort.MarshalBinary()
@@ -179,19 +174,21 @@ again:
        }
        var fns []ReceiveFunc
        if v4conn != nil {
+               s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
                if runtime.GOOS == "linux" {
                        v4pc = ipv4.NewPacketConn(v4conn)
                        s.ipv4PC = v4pc
                }
-               fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
+               fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
                s.ipv4 = v4conn
        }
        if v6conn != nil {
+               s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
                if runtime.GOOS == "linux" {
                        v6pc = ipv6.NewPacketConn(v6conn)
                        s.ipv6PC = v6pc
                }
-               fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
+               fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
                s.ipv6 = v6conn
        }
        if len(fns) == 0 {
@@ -201,69 +198,93 @@ again:
        return fns, uint16(port), nil
 }
 
-func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
-       return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-               msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
-               defer s.ipv4MsgsPool.Put(msgs)
-               for i := range bufs {
-                       (*msgs)[i].Buffers[0] = bufs[i]
-               }
-               var numMsgs int
-               if runtime.GOOS == "linux" {
-                       numMsgs, err = pc.ReadBatch(*msgs, 0)
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+       for i := range *msgs {
+               (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+       }
+       s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+       return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+       // If compilation fails here these are no longer the same underlying type.
+       _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+       ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+       WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+       br batchReader,
+       conn *net.UDPConn,
+       rxOffload bool,
+       bufs [][]byte,
+       sizes []int,
+       eps []Endpoint,
+) (n int, err error) {
+       msgs := s.getMessages()
+       for i := range bufs {
+               (*msgs)[i].Buffers[0] = bufs[i]
+               (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+       }
+       defer s.putMessages(msgs)
+       var numMsgs int
+       if runtime.GOOS == "linux" {
+               if rxOffload {
+                       readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+                       numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+                       if err != nil {
+                               return 0, err
+                       }
+                       numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
                        if err != nil {
                                return 0, err
                        }
                } else {
-                       msg := &(*msgs)[0]
-                       msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+                       numMsgs, err = br.ReadBatch(*msgs, 0)
                        if err != nil {
                                return 0, err
                        }
-                       numMsgs = 1
                }
-               for i := 0; i < numMsgs; i++ {
-                       msg := &(*msgs)[i]
-                       sizes[i] = msg.N
-                       addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-                       ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
-                       getSrcFromControl(msg.OOB[:msg.NN], ep)
-                       eps[i] = ep
+       } else {
+               msg := &(*msgs)[0]
+               msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+               if err != nil {
+                       return 0, err
+               }
+               numMsgs = 1
+       }
+       for i := 0; i < numMsgs; i++ {
+               msg := &(*msgs)[i]
+               sizes[i] = msg.N
+               if sizes[i] == 0 {
+                       continue
                }
-               return numMsgs, nil
+               addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+               ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+               getSrcFromControl(msg.OOB[:msg.NN], ep)
+               eps[i] = ep
        }
+       return numMsgs, nil
 }
 
-func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
        return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-               msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
-               defer s.ipv6MsgsPool.Put(msgs)
-               for i := range bufs {
-                       (*msgs)[i].Buffers[0] = bufs[i]
-               }
-               var numMsgs int
-               if runtime.GOOS == "linux" {
-                       numMsgs, err = pc.ReadBatch(*msgs, 0)
-                       if err != nil {
-                               return 0, err
-                       }
-               } else {
-                       msg := &(*msgs)[0]
-                       msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
-                       if err != nil {
-                               return 0, err
-                       }
-                       numMsgs = 1
-               }
-               for i := 0; i < numMsgs; i++ {
-                       msg := &(*msgs)[i]
-                       sizes[i] = msg.N
-                       addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-                       ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
-                       getSrcFromControl(msg.OOB[:msg.NN], ep)
-                       eps[i] = ep
-               }
-               return numMsgs, nil
+               return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+       }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+       return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+               return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
        }
 }
 
@@ -293,28 +314,42 @@ func (s *StdNetBind) Close() error {
        }
        s.blackhole4 = false
        s.blackhole6 = false
+       s.ipv4TxOffload = false
+       s.ipv4RxOffload = false
+       s.ipv6TxOffload = false
+       s.ipv6RxOffload = false
        if err1 != nil {
                return err1
        }
        return err2
 }
 
+type ErrUDPGSODisabled struct {
+       onLaddr  string
+       RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+       return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+       return e.RetryErr
+}
+
 func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
        s.mu.Lock()
        blackhole := s.blackhole4
        conn := s.ipv4
-       var (
-               pc4 *ipv4.PacketConn
-               pc6 *ipv6.PacketConn
-       )
+       offload := s.ipv4TxOffload
+       br := batchWriter(s.ipv4PC)
        is6 := false
        if endpoint.DstIP().Is6() {
                blackhole = s.blackhole6
                conn = s.ipv6
-               pc6 = s.ipv6PC
+               br = s.ipv6PC
                is6 = true
-       } else {
-               pc4 = s.ipv4PC
+               offload = s.ipv6TxOffload
        }
        s.mu.Unlock()
 
@@ -324,25 +359,56 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
        if conn == nil {
                return syscall.EAFNOSUPPORT
        }
+
+       msgs := s.getMessages()
+       defer s.putMessages(msgs)
+       ua := s.udpAddrPool.Get().(*net.UDPAddr)
+       defer s.udpAddrPool.Put(ua)
        if is6 {
-               return s.send6(conn, pc6, endpoint, bufs)
+               as16 := endpoint.DstIP().As16()
+               copy(ua.IP, as16[:])
+               ua.IP = ua.IP[:16]
        } else {
-               return s.send4(conn, pc4, endpoint, bufs)
+               as4 := endpoint.DstIP().As4()
+               copy(ua.IP, as4[:])
+               ua.IP = ua.IP[:4]
        }
+       ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+       var (
+               retried bool
+               err     error
+       )
+retry:
+       if offload {
+               n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+               err = s.send(conn, br, (*msgs)[:n])
+               if err != nil && offload && errShouldDisableUDPGSO(err) {
+                       offload = false
+                       s.mu.Lock()
+                       if is6 {
+                               s.ipv6TxOffload = false
+                       } else {
+                               s.ipv4TxOffload = false
+                       }
+                       s.mu.Unlock()
+                       retried = true
+                       goto retry
+               }
+       } else {
+               for i := range bufs {
+                       (*msgs)[i].Addr = ua
+                       (*msgs)[i].Buffers[0] = bufs[i]
+                       setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+               }
+               err = s.send(conn, br, (*msgs)[:len(bufs)])
+       }
+       if retried {
+               return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+       }
+       return err
 }
 
-func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
-       ua := s.udpAddrPool.Get().(*net.UDPAddr)
-       as4 := ep.DstIP().As4()
-       copy(ua.IP, as4[:])
-       ua.IP = ua.IP[:4]
-       ua.Port = int(ep.(*StdNetEndpoint).Port())
-       msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
-       for i, buf := range bufs {
-               (*msgs)[i].Buffers[0] = buf
-               (*msgs)[i].Addr = ua
-               setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
-       }
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
        var (
                n     int
                err   error
@@ -350,59 +416,128 @@ func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint,
        )
        if runtime.GOOS == "linux" {
                for {
-                       n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
-                       if err != nil || n == len((*msgs)[start:len(bufs)]) {
+                       n, err = pc.WriteBatch(msgs[start:], 0)
+                       if err != nil || n == len(msgs[start:]) {
                                break
                        }
                        start += n
                }
        } else {
-               for i, buf := range bufs {
-                       _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
+               for _, msg := range msgs {
+                       _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
                        if err != nil {
                                break
                        }
                }
        }
-       s.udpAddrPool.Put(ua)
-       s.ipv4MsgsPool.Put(msgs)
        return err
 }
 
-func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
-       ua := s.udpAddrPool.Get().(*net.UDPAddr)
-       as16 := ep.DstIP().As16()
-       copy(ua.IP, as16[:])
-       ua.IP = ua.IP[:16]
-       ua.Port = int(ep.(*StdNetEndpoint).Port())
-       msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
-       for i, buf := range bufs {
-               (*msgs)[i].Buffers[0] = buf
-               (*msgs)[i].Addr = ua
-               setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
-       }
+const (
+       // Exceeding these values results in EMSGSIZE. They account for layer3 and
+       // layer4 headers. IPv6 does not need to account for itself as the payload
+       // length field is self excluding.
+       maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+       maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+       // This is a hard limit imposed by the kernel.
+       udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
        var (
-               n     int
-               err   error
-               start int
+               base     = -1 // index of msg we are currently coalescing into
+               gsoSize  int  // segmentation size of msgs[base]
+               dgramCnt int  // number of dgrams coalesced into msgs[base]
+               endBatch bool // tracking flag to start a new batch on next iteration of bufs
        )
-       if runtime.GOOS == "linux" {
-               for {
-                       n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
-                       if err != nil || n == len((*msgs)[start:len(bufs)]) {
-                               break
+       maxPayloadLen := maxIPv4PayloadLen
+       if ep.DstIP().Is6() {
+               maxPayloadLen = maxIPv6PayloadLen
+       }
+       for i, buf := range bufs {
+               if i > 0 {
+                       msgLen := len(buf)
+                       baseLenBefore := len(msgs[base].Buffers[0])
+                       freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+                       if msgLen+baseLenBefore <= maxPayloadLen &&
+                               msgLen <= gsoSize &&
+                               msgLen <= freeBaseCap &&
+                               dgramCnt < udpSegmentMaxDatagrams &&
+                               !endBatch {
+                               msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+                               if i == len(bufs)-1 {
+                                       setGSO(&msgs[base].OOB, uint16(gsoSize))
+                               }
+                               dgramCnt++
+                               if msgLen < gsoSize {
+                                       // A smaller than gsoSize packet on the tail is legal, but
+                                       // it must end the batch.
+                                       endBatch = true
+                               }
+                               continue
                        }
-                       start += n
                }
-       } else {
-               for i, buf := range bufs {
-                       _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
-                       if err != nil {
-                               break
+               if dgramCnt > 1 {
+                       setGSO(&msgs[base].OOB, uint16(gsoSize))
+               }
+               // Reset prior to incrementing base since we are preparing to start a
+               // new potential batch.
+               endBatch = false
+               base++
+               gsoSize = len(buf)
+               setSrcControl(&msgs[base].OOB, ep)
+               msgs[base].Buffers[0] = buf
+               msgs[base].Addr = addr
+               dgramCnt = 1
+       }
+       return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+       for i := firstMsgAt; i < len(msgs); i++ {
+               msg := &msgs[i]
+               if msg.N == 0 {
+                       return n, err
+               }
+               var (
+                       gsoSize    int
+                       start      int
+                       end        = msg.N
+                       numToSplit = 1
+               )
+               gsoSize, err = getGSO(msg.OOB[:msg.NN])
+               if err != nil {
+                       return n, err
+               }
+               if gsoSize > 0 {
+                       numToSplit = (msg.N + gsoSize - 1) / gsoSize
+                       end = gsoSize
+               }
+               for j := 0; j < numToSplit; j++ {
+                       if n > i {
+                               return n, errors.New("splitting coalesced packet resulted in overflow")
                        }
+                       copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+                       msgs[n].N = copied
+                       msgs[n].Addr = msg.Addr
+                       start = end
+                       end += gsoSize
+                       if end > msg.N {
+                               end = msg.N
+                       }
+                       n++
+               }
+               if i != n-1 {
+                       // It is legal for bytes to move within msg.Buffers[0] as a result
+                       // of splitting, so we only zero the source msg len when it is not
+                       // the destination of the last split operation above.
+                       msg.N = 0
                }
        }
-       s.udpAddrPool.Put(ua)
-       s.ipv6MsgsPool.Put(msgs)
-       return err
+       return n, nil
 }
index 1e4677654ce315ff2b2d33666a0ac949ab420122..34a3c9acfb37090a485d5f0384a76923c574dbec 100644 (file)
@@ -1,6 +1,12 @@
 package conn
 
-import "testing"
+import (
+       "encoding/binary"
+       "net"
+       "testing"
+
+       "golang.org/x/net/ipv6"
+)
 
 func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
        bind := NewStdNetBind().(*StdNetBind)
@@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
                fn(bufs, sizes, eps)
        }
 }
+
+func mockSetGSOSize(control *[]byte, gsoSize uint16) {
+       *control = (*control)[:cap(*control)]
+       binary.LittleEndian.PutUint16(*control, gsoSize)
+}
+
+func Test_coalesceMessages(t *testing.T) {
+       cases := []struct {
+               name     string
+               buffs    [][]byte
+               wantLens []int
+               wantGSO  []int
+       }{
+               {
+                       name: "one message no coalesce",
+                       buffs: [][]byte{
+                               make([]byte, 1, 1),
+                       },
+                       wantLens: []int{1},
+                       wantGSO:  []int{0},
+               },
+               {
+                       name: "two messages equal len coalesce",
+                       buffs: [][]byte{
+                               make([]byte, 1, 2),
+                               make([]byte, 1, 1),
+                       },
+                       wantLens: []int{2},
+                       wantGSO:  []int{1},
+               },
+               {
+                       name: "two messages unequal len coalesce",
+                       buffs: [][]byte{
+                               make([]byte, 2, 3),
+                               make([]byte, 1, 1),
+                       },
+                       wantLens: []int{3},
+                       wantGSO:  []int{2},
+               },
+               {
+                       name: "three messages second unequal len coalesce",
+                       buffs: [][]byte{
+                               make([]byte, 2, 3),
+                               make([]byte, 1, 1),
+                               make([]byte, 2, 2),
+                       },
+                       wantLens: []int{3, 2},
+                       wantGSO:  []int{2, 0},
+               },
+               {
+                       name: "three messages limited cap coalesce",
+                       buffs: [][]byte{
+                               make([]byte, 2, 4),
+                               make([]byte, 2, 2),
+                               make([]byte, 2, 2),
+                       },
+                       wantLens: []int{4, 2},
+                       wantGSO:  []int{2, 0},
+               },
+       }
+
+       for _, tt := range cases {
+               t.Run(tt.name, func(t *testing.T) {
+                       addr := &net.UDPAddr{
+                               IP:   net.ParseIP("127.0.0.1").To4(),
+                               Port: 1,
+                       }
+                       msgs := make([]ipv6.Message, len(tt.buffs))
+                       for i := range msgs {
+                               msgs[i].Buffers = make([][]byte, 1)
+                               msgs[i].OOB = make([]byte, 0, 2)
+                       }
+                       got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
+                       if got != len(tt.wantLens) {
+                               t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
+                       }
+                       for i := 0; i < got; i++ {
+                               if msgs[i].Addr != addr {
+                                       t.Errorf("msgs[%d].Addr != passed addr", i)
+                               }
+                               gotLen := len(msgs[i].Buffers[0])
+                               if gotLen != tt.wantLens[i] {
+                                       t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
+                               }
+                               gotGSO, err := mockGetGSOSize(msgs[i].OOB)
+                               if err != nil {
+                                       t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
+                               }
+                               if gotGSO != tt.wantGSO[i] {
+                                       t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
+                               }
+                       }
+               })
+       }
+}
+
+func mockGetGSOSize(control []byte) (int, error) {
+       if len(control) < 2 {
+               return 0, nil
+       }
+       return int(binary.LittleEndian.Uint16(control)), nil
+}
+
+func Test_splitCoalescedMessages(t *testing.T) {
+       newMsg := func(n, gso int) ipv6.Message {
+               msg := ipv6.Message{
+                       Buffers: [][]byte{make([]byte, 1<<16-1)},
+                       N:       n,
+                       OOB:     make([]byte, 2),
+               }
+               binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
+               if gso > 0 {
+                       msg.NN = 2
+               }
+               return msg
+       }
+
+       cases := []struct {
+               name        string
+               msgs        []ipv6.Message
+               firstMsgAt  int
+               wantNumEval int
+               wantMsgLens []int
+               wantErr     bool
+       }{
+               {
+                       name: "second last split last empty",
+                       msgs: []ipv6.Message{
+                               newMsg(0, 0),
+                               newMsg(0, 0),
+                               newMsg(3, 1),
+                               newMsg(0, 0),
+                       },
+                       firstMsgAt:  2,
+                       wantNumEval: 3,
+                       wantMsgLens: []int{1, 1, 1, 0},
+                       wantErr:     false,
+               },
+               {
+                       name: "second last no split last empty",
+                       msgs: []ipv6.Message{
+                               newMsg(0, 0),
+                               newMsg(0, 0),
+                               newMsg(1, 0),
+                               newMsg(0, 0),
+                       },
+                       firstMsgAt:  2,
+                       wantNumEval: 1,
+                       wantMsgLens: []int{1, 0, 0, 0},
+                       wantErr:     false,
+               },
+               {
+                       name: "second last no split last no split",
+                       msgs: []ipv6.Message{
+                               newMsg(0, 0),
+                               newMsg(0, 0),
+                               newMsg(1, 0),
+                               newMsg(1, 0),
+                       },
+                       firstMsgAt:  2,
+                       wantNumEval: 2,
+                       wantMsgLens: []int{1, 1, 0, 0},
+                       wantErr:     false,
+               },
+               {
+                       name: "second last no split last split",
+                       msgs: []ipv6.Message{
+                               newMsg(0, 0),
+                               newMsg(0, 0),
+                               newMsg(1, 0),
+                               newMsg(3, 1),
+                       },
+                       firstMsgAt:  2,
+                       wantNumEval: 4,
+                       wantMsgLens: []int{1, 1, 1, 1},
+                       wantErr:     false,
+               },
+               {
+                       name: "second last split last split",
+                       msgs: []ipv6.Message{
+                               newMsg(0, 0),
+                               newMsg(0, 0),
+                               newMsg(2, 1),
+                               newMsg(2, 1),
+                       },
+                       firstMsgAt:  2,
+                       wantNumEval: 4,
+                       wantMsgLens: []int{1, 1, 1, 1},
+                       wantErr:     false,
+               },
+               {
+                       name: "second last no split last split overflow",
+                       msgs: []ipv6.Message{
+                               newMsg(0, 0),
+                               newMsg(0, 0),
+                               newMsg(1, 0),
+                               newMsg(4, 1),
+                       },
+                       firstMsgAt:  2,
+                       wantNumEval: 4,
+                       wantMsgLens: []int{1, 1, 1, 1},
+                       wantErr:     true,
+               },
+       }
+
+       for _, tt := range cases {
+               t.Run(tt.name, func(t *testing.T) {
+                       got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
+                       if err != nil && !tt.wantErr {
+                               t.Fatalf("err: %v", err)
+                       }
+                       if got != tt.wantNumEval {
+                               t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
+                       }
+                       for i, msg := range tt.msgs {
+                               if msg.N != tt.wantMsgLens[i] {
+                                       t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
+                               }
+                       }
+               })
+       }
+}
similarity index 56%
rename from conn/sticky_default.go
rename to conn/control_default.go
index 1fa8a0c4bb72dac43a47746045c2d33356c2b085..9459da555067546e13040dded1d0139118cc441c 100644 (file)
@@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string {
        return ""
 }
 
-// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
-// use alternatively named flags and need ports and require testing.
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
 
 // getSrcFromControl parses the control for PKTINFO and if found updates ep with
 // the source information found.
@@ -34,8 +35,17 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
 func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
 }
 
-// srcControlSize returns the recommended buffer size for pooling sticky control
-// data.
-const srcControlSize = 0
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+       return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// controlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const controlSize = 0
 
 const StdNetSupportsStickySockets = false
similarity index 65%
rename from conn/sticky_linux.go
rename to conn/control_linux.go
index a30ccc715c38da248ac40a53c942d61b1628b9ed..44a94e67091e9edb68d1f65d30074f49f4db6d81 100644 (file)
@@ -8,6 +8,7 @@
 package conn
 
 import (
+       "fmt"
        "net/netip"
        "unsafe"
 
@@ -105,6 +106,54 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
        *control = append(*control, ep.src...)
 }
 
-var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+const (
+       sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+       var (
+               hdr  unix.Cmsghdr
+               data []byte
+               rem  = control
+               err  error
+       )
+
+       for len(rem) > unix.SizeofCmsghdr {
+               hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+               if err != nil {
+                       return 0, fmt.Errorf("error parsing socket control message: %w", err)
+               }
+               if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+                       var gso uint16
+                       copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+                       return int(gso), nil
+               }
+       }
+       return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+       existingLen := len(*control)
+       avail := cap(*control) - existingLen
+       space := unix.CmsgSpace(sizeOfGSOData)
+       if avail < space {
+               return
+       }
+       *control = (*control)[:cap(*control)]
+       gsoControl := (*control)[existingLen:]
+       hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+       hdr.Level = unix.SOL_UDP
+       hdr.Type = unix.UDP_SEGMENT
+       hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+       copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+       *control = (*control)[:existingLen+space]
+}
+
+// controlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData)
 
 const StdNetSupportsStickySockets = true
similarity index 96%
rename from conn/sticky_linux_test.go
rename to conn/control_linux_test.go
index 679213a82cdef0da8e4d02f5ac5b2e7dc1c16635..96f9da2e004c816598076a367f8f460101b360e2 100644 (file)
@@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) {
                }
                setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
 
-               control := make([]byte, srcControlSize)
+               control := make([]byte, controlSize)
 
                setSrcControl(&control, ep)
 
@@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) {
                }
                setSrc(ep, netip.MustParseAddr("::1"), 5)
 
-               control := make([]byte, srcControlSize)
+               control := make([]byte, controlSize)
 
                setSrcControl(&control, ep)
 
@@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
        })
 
        t.Run("ClearOnNoSrc", func(t *testing.T) {
-               control := make([]byte, unix.CmsgLen(0))
+               control := make([]byte, controlSize)
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = 1
                hdr.Type = 2
@@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
 
 func Test_getSrcFromControl(t *testing.T) {
        t.Run("IPv4", func(t *testing.T) {
-               control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+               control := make([]byte, controlSize)
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = unix.IPPROTO_IP
                hdr.Type = unix.IP_PKTINFO
@@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) {
                }
        })
        t.Run("IPv6", func(t *testing.T) {
-               control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+               control := make([]byte, controlSize)
                hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
                hdr.Level = unix.IPPROTO_IPV6
                hdr.Type = unix.IPV6_PKTINFO
index a2396fe899c6bbeef424d1dd3e6fb72160a0fbdc..f6ab1d2ec45ce8d1625aa5b8ebe075aa294b96b2 100644 (file)
@@ -57,5 +57,13 @@ func init() {
                        }
                        return err
                },
+
+               // Attempt to enable UDP_GRO
+               func(network, address string, c syscall.RawConn) error {
+                       c.Control(func(fd uintptr) {
+                               _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+                       })
+                       return nil
+               },
        )
 }
diff --git a/conn/errors_default.go b/conn/errors_default.go
new file mode 100644 (file)
index 0000000..f1e5b90
--- /dev/null
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+       return false
+}
diff --git a/conn/errors_linux.go b/conn/errors_linux.go
new file mode 100644 (file)
index 0000000..8e61000
--- /dev/null
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+       "errors"
+       "os"
+
+       "golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+       var serr *os.SyscallError
+       if errors.As(err, &serr) {
+               // EIO is returned by udp_send_skb() if the device driver does not have
+               // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+               // See:
+               // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+               // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+               return serr.Err == unix.EIO
+       }
+       return false
+}
diff --git a/conn/features_default.go b/conn/features_default.go
new file mode 100644 (file)
index 0000000..d53ff5f
--- /dev/null
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+       return
+}
diff --git a/conn/features_linux.go b/conn/features_linux.go
new file mode 100644 (file)
index 0000000..e1fb57f
--- /dev/null
@@ -0,0 +1,35 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+       "net"
+
+       "golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+       rc, err := conn.SyscallConn()
+       if err != nil {
+               return
+       }
+       err = rc.Control(func(fd uintptr) {
+               _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+               if errSyscall != nil {
+                       return
+               }
+               txOffload = true
+               opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+               if errSyscall != nil {
+                       return
+               }
+               rxOffload = opt == 1
+       })
+       if err != nil {
+               return false, false
+       }
+       return txOffload, rxOffload
+}
index d22bf264e9f2e43325a701d74f3579878d0c705b..cd8a2a0ddf0bad8d8ae74e154c45ebdf7ac0dfb0 100644 (file)
@@ -17,6 +17,7 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/net/ipv4"
        "golang.org/x/net/ipv6"
+       "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/tun"
 )
 
@@ -525,6 +526,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
                        device.PutOutboundElement(elem)
                }
                device.PutOutboundElementsSlice(elems)
+               if err != nil {
+                       var errGSO conn.ErrUDPGSODisabled
+                       if errors.As(err, &errGSO) {
+                               device.log.Verbosef(err.Error())
+                               err = errGSO.RetryErr
+                       }
+               }
                if err != nil {
                        device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
                        continue
diff --git a/go.mod b/go.mod
index c04e1bb61daf216ce1c0de455a3ff0f5b3c5a9d8..758dcde7300e5fefb0f7f76011c2055b448836b5 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.20
 require (
        golang.org/x/crypto v0.6.0
        golang.org/x/net v0.7.0
-       golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
+       golang.org/x/sys v0.12.0
        golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
        gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
 )
diff --git a/go.sum b/go.sum
index cfeaee6235b6207ad28aafde0882931faaead45b..fe4ca7e773353a108b56bb5dce08d3f48d73f5f2 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -4,8 +4,8 @@ golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
 golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
 golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
 golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
-golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
-golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=