]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: implement UDP GSO/GRO for Linux
author Jordan Whited <jordan@tailscale.com>
Wed, 1 Nov 2023 02:53:35 +0000 (19:53 -0700)
committer Jason A. Donenfeld <Jason@zx2c4.com>
Mon, 11 Dec 2023 15:27:22 +0000 (16:27 +0100)
Implement UDP GSO and GRO for the Linux tun.Device, which is made
possible by virtio extensions in the kernel's TUN driver starting in
v6.2.

secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is
used to demonstrate the effect of this commit between two Linux
computers with i5-12400 CPUs. There is roughly 13us of round trip
latency between them. secnetperf was invoked with the following command
line options:
-stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0

The first result is from commit 2e0774f without UDP GSO/GRO on the TUN.

[conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us
SendTotalPackets=55859 SendSuspectedLostPackets=61
SendSpuriousLostPackets=59 SendCongestionCount=27
SendEcnCongestionCount=0 RecvTotalPackets=2779122
RecvReorderedPackets=0 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms).

The second result is with UDP GSO/GRO on the TUN.

[conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us
SendTotalPackets=165033 SendSuspectedLostPackets=64
SendSpuriousLostPackets=61 SendCongestionCount=53
SendEcnCongestionCount=0 RecvTotalPackets=11845268
RecvReorderedPackets=25267 RecvDroppedPackets=0
RecvDuplicatePackets=0 RecvDecryptionFailures=0
Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms).

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/offload_linux.go [moved from tun/tcp_offload_linux.go with 50% similarity]
tun/offload_linux_test.go [new file with mode: 0644]
tun/tcp_offload_linux_test.go [deleted file]
tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 [deleted file]
tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d [deleted file]
tun/tun_linux.go

similarity index 50%
rename from tun/tcp_offload_linux.go
rename to tun/offload_linux.go
index 1afd27edfb331c81e78f228709f0796f067482d9..9ff7fea8f98f27ac2197d87dbc4db5399af9aefc 100644 (file)
@@ -57,22 +57,23 @@ const (
        virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
 )
 
-// flowKey represents the key for a flow.
-type flowKey struct {
+// tcpFlowKey represents the key for a TCP flow.
+type tcpFlowKey struct {
        srcAddr, dstAddr [16]byte
        srcPort, dstPort uint16
        rxAck            uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+       isV6             bool
 }
 
-// tcpGROTable holds flow and coalescing information for the purposes of GRO.
+// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
 type tcpGROTable struct {
-       itemsByFlow map[flowKey][]tcpGROItem
+       itemsByFlow map[tcpFlowKey][]tcpGROItem
        itemsPool   [][]tcpGROItem
 }
 
 func newTCPGROTable() *tcpGROTable {
        t := &tcpGROTable{
-               itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize),
+               itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
                itemsPool:   make([][]tcpGROItem, conn.IdealBatchSize),
        }
        for i := range t.itemsPool {
@@ -81,14 +82,15 @@ func newTCPGROTable() *tcpGROTable {
        return t
 }
 
-func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
-       key := flowKey{}
-       addrSize := dstAddr - srcAddr
-       copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
-       copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
+func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
+       key := tcpFlowKey{}
+       addrSize := dstAddrOffset - srcAddrOffset
+       copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+       copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
        key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
        key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
        key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+       key.isV6 = addrSize == 16
        return key
 }
 
@@ -96,7 +98,7 @@ func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
 // returning the packets found for the flow, or inserting a new one if none
 // is found.
 func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
-       key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+       key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
        items, ok := t.itemsByFlow[key]
        if ok {
                return items, ok
@@ -108,7 +110,7 @@ func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, t
 
 // insert an item in the table for the provided packet and packet metadata.
 func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
-       key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+       key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
        item := tcpGROItem{
                key:       key,
                bufsIndex: uint16(bufsIndex),
@@ -131,7 +133,7 @@ func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
        items[i] = item
 }
 
-func (t *tcpGROTable) deleteAt(key flowKey, i int) {
+func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
        items, _ := t.itemsByFlow[key]
        items = append(items[:i], items[i+1:]...)
        t.itemsByFlow[key] = items
@@ -140,7 +142,7 @@ func (t *tcpGROTable) deleteAt(key flowKey, i int) {
 // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
 // of a GRO evaluation across a vector of packets.
 type tcpGROItem struct {
-       key       flowKey
+       key       tcpFlowKey
        sentSeq   uint32 // the sequence number
        bufsIndex uint16 // the index into the original bufs slice
        numMerged uint16 // the number of packets merged into this item
@@ -164,6 +166,103 @@ func (t *tcpGROTable) reset() {
        }
 }
 
+// udpFlowKey represents the key for a UDP flow.
+type udpFlowKey struct {
+       srcAddr, dstAddr [16]byte
+       srcPort, dstPort uint16
+       isV6             bool
+}
+
+// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
+type udpGROTable struct {
+       itemsByFlow map[udpFlowKey][]udpGROItem
+       itemsPool   [][]udpGROItem
+}
+
+func newUDPGROTable() *udpGROTable {
+       u := &udpGROTable{
+               itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
+               itemsPool:   make([][]udpGROItem, conn.IdealBatchSize),
+       }
+       for i := range u.itemsPool {
+               u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
+       }
+       return u
+}
+
+func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
+       key := udpFlowKey{}
+       addrSize := dstAddrOffset - srcAddrOffset
+       copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+       copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+       key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
+       key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
+       key.isV6 = addrSize == 16
+       return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
+       key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+       items, ok := u.itemsByFlow[key]
+       if ok {
+               return items, ok
+       }
+       // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+       u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
+       return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
+       key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+       item := udpGROItem{
+               key:              key,
+               bufsIndex:        uint16(bufsIndex),
+               gsoSize:          uint16(len(pkt[udphOffset+udphLen:])),
+               iphLen:           uint8(udphOffset),
+               cSumKnownInvalid: cSumKnownInvalid,
+       }
+       items, ok := u.itemsByFlow[key]
+       if !ok {
+               items = u.newItems()
+       }
+       items = append(items, item)
+       u.itemsByFlow[key] = items
+}
+
+func (u *udpGROTable) updateAt(item udpGROItem, i int) {
+       items, _ := u.itemsByFlow[item.key]
+       items[i] = item
+}
+
+// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type udpGROItem struct {
+       key              udpFlowKey
+       bufsIndex        uint16 // the index into the original bufs slice
+       numMerged        uint16 // the number of packets merged into this item
+       gsoSize          uint16 // payload size
+       iphLen           uint8  // ip header len
+       cSumKnownInvalid bool   // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
+}
+
+func (u *udpGROTable) newItems() []udpGROItem {
+       var items []udpGROItem
+       items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
+       return items
+}
+
+func (u *udpGROTable) reset() {
+       for k, items := range u.itemsByFlow {
+               items = items[:0]
+               u.itemsPool = append(u.itemsPool, items)
+               delete(u.itemsByFlow, k)
+       }
+}
+
 // canCoalesce represents the outcome of checking if two TCP packets are
 // candidates for coalescing.
 type canCoalesce int
@@ -174,6 +273,61 @@ const (
        coalesceAppend      canCoalesce = 1
 )
 
+// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
+// meet all requirements to be merged as part of a GRO operation, otherwise it
+// returns false.
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
+       if len(pktA) < 9 || len(pktB) < 9 {
+               return false
+       }
+       if pktA[0]>>4 == 6 {
+               if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
+                       // cannot coalesce with unequal Traffic class values
+                       return false
+               }
+               if pktA[7] != pktB[7] {
+                       // cannot coalesce with unequal Hop limit values
+                       return false
+               }
+       } else {
+               if pktA[1] != pktB[1] {
+                       // cannot coalesce with unequal ToS values
+                       return false
+               }
+               if pktA[6]>>5 != pktB[6]>>5 {
+                       // cannot coalesce with unequal DF or reserved bits. MF is checked
+                       // further up the stack.
+                       return false
+               }
+               if pktA[8] != pktB[8] {
+                       // cannot coalesce with unequal TTL values
+                       return false
+               }
+       }
+       return true
+}
+
+// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
+// packets involved in the current GRO evaluation. bufsOffset is the offset at
+// which packet data begins within bufs.
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+       pktTarget := bufs[item.bufsIndex][bufsOffset:]
+       if !ipHeadersCanCoalesce(pkt, pktTarget) {
+               return coalesceUnavailable
+       }
+       if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
+               // A smaller than gsoSize packet has been appended previously.
+               // Nothing can come after a smaller packet on the end.
+               return coalesceUnavailable
+       }
+       if gsoSize > item.gsoSize {
+               // We cannot have a larger packet following a smaller one.
+               return coalesceUnavailable
+       }
+       return coalesceAppend
+}
+
 // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
 // described by item. This function makes considerations that match the kernel's
 // GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
@@ -189,29 +343,8 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
                        return coalesceUnavailable
                }
        }
-       if pkt[0]>>4 == 6 {
-               if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
-                       // cannot coalesce with unequal Traffic class values
-                       return coalesceUnavailable
-               }
-               if pkt[7] != pktTarget[7] {
-                       // cannot coalesce with unequal Hop limit values
-                       return coalesceUnavailable
-               }
-       } else {
-               if pkt[1] != pktTarget[1] {
-                       // cannot coalesce with unequal ToS values
-                       return coalesceUnavailable
-               }
-               if pkt[6]>>5 != pktTarget[6]>>5 {
-                       // cannot coalesce with unequal DF or reserved bits. MF is checked
-                       // further up the stack.
-                       return coalesceUnavailable
-               }
-               if pkt[8] != pktTarget[8] {
-                       // cannot coalesce with unequal TTL values
-                       return coalesceUnavailable
-               }
+       if !ipHeadersCanCoalesce(pkt, pktTarget) {
+               return coalesceUnavailable
        }
        // seq adjacency
        lhsLen := item.gsoSize
@@ -252,16 +385,16 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
        return coalesceUnavailable
 }
 
-func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
+func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
        srcAddrAt := ipv4SrcAddrOffset
        addrSize := 4
        if isV6 {
                srcAddrAt = ipv6SrcAddrOffset
                addrSize = 16
        }
-       tcpTotalLen := uint16(len(pkt) - int(iphLen))
-       tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
-       return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
+       lenForPseudo := uint16(len(pkt) - int(iphLen))
+       cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
+       return ^checksum(pkt[iphLen:], cSum) == 0
 }
 
 // coalesceResult represents the result of attempting to coalesce two TCP
@@ -276,8 +409,36 @@ const (
        coalesceSuccess
 )
 
+// coalesceUDPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome.
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+       pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
+       headersLen := item.iphLen + udphLen
+       coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+       if cap(pktHead)-bufsOffset < coalescedLen {
+               // We don't want to allocate a new underlying array if capacity is
+               // too small.
+               return coalesceInsufficientCap
+       }
+       if item.numMerged == 0 {
+               if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
+                       return coalesceItemInvalidCSum
+               }
+       }
+       if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
+               return coalescePktInvalidCSum
+       }
+       extendBy := len(pkt) - int(headersLen)
+       bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+       copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+
+       item.numMerged++
+       return coalesceSuccess
+}
+
 // coalesceTCPPackets attempts to coalesce pkt with the packet described by
-// item, returning the outcome. This function may swap bufs elements in the
+// item, and returns the outcome. This function may swap bufs elements in the
 // event of a prepend as item's bufs index is already being tracked for writing
 // to a Device.
 func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
@@ -297,11 +458,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
                        return coalescePSHEnding
                }
                if item.numMerged == 0 {
-                       if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
+                       if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
                                return coalesceItemInvalidCSum
                        }
                }
-               if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+               if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
                        return coalescePktInvalidCSum
                }
                item.sentSeq = seq
@@ -319,11 +480,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
                        return coalesceInsufficientCap
                }
                if item.numMerged == 0 {
-                       if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
+                       if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
                                return coalesceItemInvalidCSum
                        }
                }
-               if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+               if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
                        return coalescePktInvalidCSum
                }
                if pshSet {
@@ -354,52 +515,52 @@ const (
        maxUint16         = 1<<16 - 1
 )
 
-type tcpGROResult int
+type groResult int
 
 const (
-       tcpGROResultNoop tcpGROResult = iota
-       tcpGROResultTableInsert
-       tcpGROResultCoalesced
+       groResultNoop groResult = iota
+       groResultTableInsert
+       groResultCoalesced
 )
 
 // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
-// existing packets tracked in table. It returns a tcpGROResultNoop when no
-// action was taken, tcpGROResultTableInsert when the evaluated packet was
-// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
 // coalesced with another packet in table.
-func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
        pkt := bufs[pktI][offset:]
        if len(pkt) > maxUint16 {
                // A valid IPv4 or IPv6 packet will never exceed this.
-               return tcpGROResultNoop
+               return groResultNoop
        }
        iphLen := int((pkt[0] & 0x0F) * 4)
        if isV6 {
                iphLen = 40
                ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
                if ipv6HPayloadLen != len(pkt)-iphLen {
-                       return tcpGROResultNoop
+                       return groResultNoop
                }
        } else {
                totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
                if totalLen != len(pkt) {
-                       return tcpGROResultNoop
+                       return groResultNoop
                }
        }
        if len(pkt) < iphLen {
-               return tcpGROResultNoop
+               return groResultNoop
        }
        tcphLen := int((pkt[iphLen+12] >> 4) * 4)
        if tcphLen < 20 || tcphLen > 60 {
-               return tcpGROResultNoop
+               return groResultNoop
        }
        if len(pkt) < iphLen+tcphLen {
-               return tcpGROResultNoop
+               return groResultNoop
        }
        if !isV6 {
                if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
                        // no GRO support for fragmented segments for now
-                       return tcpGROResultNoop
+                       return groResultNoop
                }
        }
        tcpFlags := pkt[iphLen+tcpFlagsOffset]
@@ -407,14 +568,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
        // not a candidate if any non-ACK flags (except PSH+ACK) are set
        if tcpFlags != tcpFlagACK {
                if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
-                       return tcpGROResultNoop
+                       return groResultNoop
                }
                pshSet = true
        }
        gsoSize := uint16(len(pkt) - tcphLen - iphLen)
        // not a candidate if payload len is 0
        if gsoSize < 1 {
-               return tcpGROResultNoop
+               return groResultNoop
        }
        seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
        srcAddrOffset := ipv4SrcAddrOffset
@@ -425,7 +586,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
        }
        items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
        if !existing {
-               return tcpGROResultNoop
+               return groResultTableInsert
        }
        for i := len(items) - 1; i >= 0; i-- {
                // In the best case of packets arriving in order iterating in reverse is
@@ -443,54 +604,25 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
                        switch result {
                        case coalesceSuccess:
                                table.updateAt(item, i)
-                               return tcpGROResultCoalesced
+                               return groResultCoalesced
                        case coalesceItemInvalidCSum:
                                // delete the item with an invalid csum
                                table.deleteAt(item.key, i)
                        case coalescePktInvalidCSum:
                                // no point in inserting an item that we can't coalesce
-                               return tcpGROResultNoop
+                               return groResultNoop
                        default:
                        }
                }
        }
        // failed to coalesce with any other packets; store the item in the flow
        table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
-       return tcpGROResultTableInsert
-}
-
-func isTCP4NoIPOptions(b []byte) bool {
-       if len(b) < 40 {
-               return false
-       }
-       if b[0]>>4 != 4 {
-               return false
-       }
-       if b[0]&0x0F != 5 {
-               return false
-       }
-       if b[9] != unix.IPPROTO_TCP {
-               return false
-       }
-       return true
+       return groResultTableInsert
 }
 
-func isTCP6NoEH(b []byte) bool {
-       if len(b) < 60 {
-               return false
-       }
-       if b[0]>>4 != 6 {
-               return false
-       }
-       if b[6] != unix.IPPROTO_TCP {
-               return false
-       }
-       return true
-}
-
-// applyCoalesceAccounting updates bufs to account for coalescing based on the
+// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
 // metadata found in table.
-func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
        for _, items := range table.itemsByFlow {
                for _, item := range items {
                        if item.numMerged > 0 {
@@ -505,7 +637,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
 
                                // Recalculate the total len (IPv4) or payload len (IPv6).
                                // Recalculate the (IPv4) header checksum.
-                               if isV6 {
+                               if item.key.isV6 {
                                        hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
                                        binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
                                } else {
@@ -525,7 +657,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
                                // this with computation of the tcp header and payload checksum.
                                addrLen := 4
                                addrOffset := ipv4SrcAddrOffset
-                               if isV6 {
+                               if item.key.isV6 {
                                        addrLen = 16
                                        addrOffset = ipv6SrcAddrOffset
                                }
@@ -546,54 +678,245 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
        return nil
 }
 
+// applyUDPCoalesceAccounting walks every item in table and rewrites its packet in
+// bufs: merged items get a UDP GSO virtio header, unmerged ones a zero-value one.
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
+       for _, items := range table.itemsByFlow {
+               for _, item := range items {
+                       if item.numMerged > 0 {
+                               hdr := virtioNetHdr{
+                                       flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+                                       hdrLen:     uint16(item.iphLen + udphLen),
+                                       gsoSize:    item.gsoSize,
+                                       csumStart:  uint16(item.iphLen),
+                                       csumOffset: 6, // position of the checksum field within the UDP header
+                               }
+                               pkt := bufs[item.bufsIndex][offset:] // the IP packet; the virtio header is encoded just before it
+
+                               // Recalculate the total len (IPv4) or payload len (IPv6).
+                               // Recalculate the (IPv4) header checksum.
+                               hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
+                               if item.key.isV6 {
+                                       binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+                               } else {
+                                       pkt[10], pkt[11] = 0, 0 // checksum field must be zero while the header checksum is computed
+                                       binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+                                       iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
+                                       binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
+                               }
+                               err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+                               if err != nil {
+                                       return err
+                               }
+
+                               // Recalculate the UDP length field (UDP header plus coalesced payload)
+                               binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
+
+                               // Calculate the pseudo header checksum and place it at the UDP
+                               // checksum offset. Downstream checksum offloading will combine
+                               // this with computation of the udp header and payload checksum.
+                               addrLen := 4
+                               addrOffset := ipv4SrcAddrOffset
+                               if item.key.isV6 {
+                                       addrLen = 16
+                                       addrOffset = ipv6SrcAddrOffset
+                               }
+                               srcAddrAt := offset + addrOffset
+                               srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+                               dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] // destination address directly follows the source address
+                               psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+                               binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+                       } else { // nothing was merged into this item: encode a zero-value virtioNetHdr
+                               hdr := virtioNetHdr{}
+                               err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+                               if err != nil {
+                                       return err
+                               }
+                       }
+               }
+       }
+       return nil
+}
+
+type groCandidateType uint8 // classification returned by packetIsGROCandidate
+
+const (
+       notGROCandidate groCandidateType = iota // zero value: packet is not considered for GRO
+       tcp4GROCandidate
+       tcp6GROCandidate
+       udp4GROCandidate
+       udp6GROCandidate
+)
+
+func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { // classifies IP packet b; UDP results require canUDPGRO
+       if len(b) < 28 { // shorter than a 20-byte IPv4 header plus an 8-byte UDP header
+               return notGROCandidate
+       }
+       if b[0]>>4 == 4 { // IP version nibble
+               if b[0]&0x0F != 5 {
+                       // IPv4 packets w/IP options do not coalesce
+                       return notGROCandidate
+               }
+               if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { // 20-byte IPv4 header + 20-byte minimum TCP header
+                       return tcp4GROCandidate
+               }
+               if b[9] == unix.IPPROTO_UDP && canUDPGRO {
+                       return udp4GROCandidate
+               }
+       } else if b[0]>>4 == 6 { // b[6] is the IPv6 next header field, so extension headers never match below
+               if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { // 40-byte IPv6 header + 20-byte minimum TCP header
+                       return tcp6GROCandidate
+               }
+               if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { // 40-byte IPv6 header + 8-byte UDP header
+                       return udp6GROCandidate
+               }
+       }
+       return notGROCandidate
+}
+
+const (
+       udphLen = 8 // length of a UDP header in bytes
+)
+
+// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
+       pkt := bufs[pktI][offset:]
+       if len(pkt) > maxUint16 {
+               // A valid IPv4 or IPv6 packet will never exceed this.
+               return groResultNoop
+       }
+       iphLen := int((pkt[0] & 0x0F) * 4) // IPv4 IHL counts 32-bit words; replaced below for IPv6
+       if isV6 {
+               iphLen = 40 // fixed IPv6 header length; no extension headers are parsed here
+               ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+               if ipv6HPayloadLen != len(pkt)-iphLen { // payload length field must match the buffer exactly
+                       return groResultNoop
+               }
+       } else {
+               totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+               if totalLen != len(pkt) { // total length field must match the buffer exactly
+                       return groResultNoop
+               }
+       }
+       if len(pkt) < iphLen { // NOTE(review): subsumed by the iphLen+udphLen check that follows
+               return groResultNoop
+       }
+       if len(pkt) < iphLen+udphLen {
+               return groResultNoop
+       }
+       if !isV6 {
+               if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // flag bit set, or non-zero fragment offset (low 5 bits of pkt[6], all of pkt[7])
+                       // no GRO support for fragmented segments for now
+                       return groResultNoop
+               }
+       }
+       gsoSize := uint16(len(pkt) - udphLen - iphLen) // UDP payload length of this packet
+       // not a candidate if payload len is 0
+       if gsoSize < 1 {
+               return groResultNoop
+       }
+       srcAddrOffset := ipv4SrcAddrOffset
+       addrLen := 4
+       if isV6 {
+               srcAddrOffset = ipv6SrcAddrOffset
+               addrLen = 16
+       }
+       items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
+       if !existing { // new flow: lookupOrInsert (per its name) has already stored pkt
+               return groResultTableInsert
+       }
+       // With UDP we only check the last item, otherwise we could reorder packets
+       // for a given flow. We must also always insert a new item, or successfully
+       // coalesce with an existing item, for the same reason.
+       item := items[len(items)-1]
+       can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
+       var pktCSumKnownInvalid bool
+       if can == coalesceAppend {
+               result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
+               switch result {
+               case coalesceSuccess:
+                       table.updateAt(item, len(items)-1) // write the modified item back to its slot
+                       return groResultCoalesced
+               case coalesceItemInvalidCSum:
+                       // If the existing item has an invalid csum we take no action. A new
+                       // item will be stored after it, and the existing item will never be
+                       // revisited as part of future coalescing candidacy checks.
+               case coalescePktInvalidCSum:
+                       // We must insert a new item, but we also mark it as invalid csum
+                       // to prevent a repeat checksum validation.
+                       pktCSumKnownInvalid = true
+               default: // any other result also ends up at the insert below
+               }
+       }
+       // failed to coalesce with any other packets; store the item in the flow
+       table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
+       return groResultTableInsert
+}
+
 // handleGRO evaluates bufs for GRO, and writes the indices of the resulting
-// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
 // empty (but non-nil), and are passed in to save allocs as the caller may reset
-// and recycle them across vectors of packets.
-func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
+// supported.
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
        for i := range bufs {
                if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
                        return errors.New("invalid offset")
                }
-               var result tcpGROResult
-               switch {
-               case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
-                       result = tcpGRO(bufs, offset, i, tcp4Table, false)
-               case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
-                       result = tcpGRO(bufs, offset, i, tcp6Table, true)
+               var result groResult
+               switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { // notGROCandidate matches no case, leaving result at its zero value
+               case tcp4GROCandidate:
+                       result = tcpGRO(bufs, offset, i, tcpTable, false)
+               case tcp6GROCandidate:
+                       result = tcpGRO(bufs, offset, i, tcpTable, true)
+               case udp4GROCandidate:
+                       result = udpGRO(bufs, offset, i, udpTable, false)
+               case udp6GROCandidate:
+                       result = udpGRO(bufs, offset, i, udpTable, true)
                }
                switch result {
-               case tcpGROResultNoop:
+               case groResultNoop: // packet is written as-is, so it gets a zero-value virtio header here
                        hdr := virtioNetHdr{}
                        err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
                        if err != nil {
                                return err
                        }
                        fallthrough
-               case tcpGROResultTableInsert:
+               case groResultTableInsert: // every other result adds nothing to toWrite
                        *toWrite = append(*toWrite, i)
                }
        }
-       err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
-       err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
-       return errors.Join(err4, err6)
+       errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) // both tables are always processed, even if this one fails
+       errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
+       return errors.Join(errTCP, errUDP)
 }
 
-// tcpTSO splits packets from in into outBuffs, writing the size of each
+// gsoSplit splits packets from in into outBuffs, writing the size of each
 // element into sizes. It returns the number of buffers populated, and/or an
 // error.
-func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
        iphLen := int(hdr.csumStart)
        srcAddrOffset := ipv6SrcAddrOffset
        addrLen := 16
-       if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+       if !isV6 {
                in[10], in[11] = 0, 0 // clear ipv4 header checksum
                srcAddrOffset = ipv4SrcAddrOffset
                addrLen = 4
        }
-       tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
-       in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
-       firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+       transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
+       in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
+       var firstTCPSeqNum uint32
+       var protocol uint8
+       if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+               protocol = unix.IPPROTO_TCP
+               firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+       } else {
+               protocol = unix.IPPROTO_UDP
+       }
        nextSegmentDataAt := int(hdr.hdrLen)
        i := 0
        for ; nextSegmentDataAt < len(in); i++ {
@@ -610,7 +933,7 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
                out := outBuffs[i][outOffset:]
 
                copy(out, in[:iphLen])
-               if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+               if !isV6 {
                        // For IPv4 we are responsible for incrementing the ID field,
                        // updating the total len field, and recalculating the header
                        // checksum.
@@ -627,25 +950,32 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
                        binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
                }
 
-               // TCP header
+               // copy transport header
                copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
-               tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
-               binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
-               if nextSegmentEnd != len(in) {
-                       // FIN and PSH should only be set on last segment
-                       clearFlags := tcpFlagFIN | tcpFlagPSH
-                       out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+
+               if protocol == unix.IPPROTO_TCP {
+                       // set TCP seq and adjust TCP flags
+                       tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+                       binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+                       if nextSegmentEnd != len(in) {
+                               // FIN and PSH should only be set on last segment
+                               clearFlags := tcpFlagFIN | tcpFlagPSH
+                               out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+                       }
+               } else {
+                       // set UDP length field (UDP header + this segment's payload)
+                       binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
                }
 
                // payload
                copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
 
-               // TCP checksum
-               tcpHLen := int(hdr.hdrLen - hdr.csumStart)
-               tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
-               tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
-               tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
-               binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+               // transport checksum
+               transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
+               lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
+               transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
+               transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
+               binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
 
                nextSegmentDataAt += int(hdr.gsoSize)
        }
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644 (file)
index 0000000..ae55c8c
--- /dev/null
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+       "net/netip"
+       "testing"
+
+       "golang.org/x/sys/unix"
+       "golang.zx2c4.com/wireguard/conn"
+       "gvisor.dev/gvisor/pkg/tcpip"
+       "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+       offset = virtioNetHdrLen // test packets start right after the space reserved for the virtio header
+)
+
+var (
+       ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") // 192.0.2.0/24 is the TEST-NET-1 documentation range
+       ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+       ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+       ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") // 2001:db8::/32 is the IPv6 documentation prefix
+       ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+       ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { // builds a UDP/IPv4 test packet
+       totalLen := 28 + payloadLen // 20-byte IPv4 header + 8-byte UDP header + payload
+       b := make([]byte, offset+int(totalLen), 65535) // first offset bytes are left for the virtio header; payload stays zeroed
+       ipv4H := header.IPv4(b[offset:])
+       srcAs4 := srcIPPort.Addr().As4()
+       dstAs4 := dstIPPort.Addr().As4()
+       ipFields := &header.IPv4Fields{
+               SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
+               DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
+               Protocol:    unix.IPPROTO_UDP,
+               TTL:         64,
+               TotalLength: uint16(totalLen),
+       }
+       if ipFn != nil { // optional hook to alter the IP fields before they are encoded
+               ipFn(ipFields)
+       }
+       ipv4H.Encode(ipFields)
+       udpH := header.UDP(b[offset+20:]) // UDP header follows the 20-byte IPv4 header
+       udpH.Encode(&header.UDPFields{
+               SrcPort: srcIPPort.Port(),
+               DstPort: dstIPPort.Port(),
+               Length:  uint16(payloadLen + udphLen),
+       })
+       ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+       pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+       udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+       return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+       return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) // nil: no IP field mutation
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { // builds a UDP/IPv6 test packet
+       totalLen := 48 + payloadLen // 40-byte IPv6 header + 8-byte UDP header + payload
+       b := make([]byte, offset+int(totalLen), 65535) // first offset bytes are left for the virtio header; payload stays zeroed
+       ipv6H := header.IPv6(b[offset:])
+       srcAs16 := srcIPPort.Addr().As16()
+       dstAs16 := dstIPPort.Addr().As16()
+       ipFields := &header.IPv6Fields{
+               SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
+               DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
+               TransportProtocol: unix.IPPROTO_UDP,
+               HopLimit:          64,
+               PayloadLength:     uint16(payloadLen + udphLen),
+       }
+       if ipFn != nil { // optional hook to alter the IP fields before they are encoded
+               ipFn(ipFields)
+       }
+       ipv6H.Encode(ipFields)
+       udpH := header.UDP(b[offset+40:]) // UDP header follows the 40-byte IPv6 header
+       udpH.Encode(&header.UDPFields{
+               SrcPort: srcIPPort.Port(),
+               DstPort: dstIPPort.Port(),
+               Length:  uint16(payloadLen + udphLen),
+       })
+       pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+       udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+       return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+       return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) // nil: no IP field mutation
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { // builds a TCP/IPv4 test packet
+       totalLen := 40 + segmentSize // 20-byte IPv4 header + 20-byte TCP header + payload
+       b := make([]byte, offset+int(totalLen), 65535) // first offset bytes are left for the virtio header; payload stays zeroed
+       ipv4H := header.IPv4(b[offset:])
+       srcAs4 := srcIPPort.Addr().As4()
+       dstAs4 := dstIPPort.Addr().As4()
+       ipFields := &header.IPv4Fields{
+               SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
+               DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
+               Protocol:    unix.IPPROTO_TCP,
+               TTL:         64,
+               TotalLength: uint16(totalLen),
+       }
+       if ipFn != nil { // optional hook to alter the IP fields before they are encoded
+               ipFn(ipFields)
+       }
+       ipv4H.Encode(ipFields)
+       tcpH := header.TCP(b[offset+20:]) // TCP header follows the 20-byte IPv4 header
+       tcpH.Encode(&header.TCPFields{
+               SrcPort:    srcIPPort.Port(),
+               DstPort:    dstIPPort.Port(),
+               SeqNum:     seq,
+               AckNum:     1,
+               DataOffset: 20, // TCP header length in bytes, i.e. no TCP options
+               Flags:      flags,
+               WindowSize: 3000,
+       })
+       ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+       pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+       tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+       return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+       return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) // nil: no IP field mutation
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { // builds a TCP/IPv6 test packet
+       totalLen := 60 + segmentSize // 40-byte IPv6 header + 20-byte TCP header + payload
+       b := make([]byte, offset+int(totalLen), 65535) // first offset bytes are left for the virtio header; payload stays zeroed
+       ipv6H := header.IPv6(b[offset:])
+       srcAs16 := srcIPPort.Addr().As16()
+       dstAs16 := dstIPPort.Addr().As16()
+       ipFields := &header.IPv6Fields{
+               SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
+               DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
+               TransportProtocol: unix.IPPROTO_TCP,
+               HopLimit:          64,
+               PayloadLength:     uint16(segmentSize + 20),
+       }
+       if ipFn != nil { // optional hook to alter the IP fields before they are encoded
+               ipFn(ipFields)
+       }
+       ipv6H.Encode(ipFields)
+       tcpH := header.TCP(b[offset+40:]) // TCP header follows the 40-byte IPv6 header
+       tcpH.Encode(&header.TCPFields{
+               SrcPort:    srcIPPort.Port(),
+               DstPort:    dstIPPort.Port(),
+               SeqNum:     seq,
+               AckNum:     1,
+               DataOffset: 20, // TCP header length in bytes, i.e. no TCP options
+               Flags:      flags,
+               WindowSize: 3000,
+       })
+       pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+       tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+       return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+       return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) // nil: no IP field mutation
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+       tests := []struct {
+               name     string
+               hdr      virtioNetHdr
+               pktIn    []byte
+               wantLens []int // expected size of each output packet, in order
+               wantErr  bool
+       }{
+               {
+                       "tcp4",
+                       virtioNetHdr{
+                               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+                               gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
+                               gsoSize:    100,
+                               hdrLen:     40,
+                               csumStart:  20,
+                               csumOffset: 16,
+                       },
+                       tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+                       []int{140, 140}, // 200 payload bytes at gsoSize 100: two packets of 40 header + 100 payload
+                       false,
+               },
+               {
+                       "tcp6",
+                       virtioNetHdr{
+                               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+                               gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
+                               gsoSize:    100,
+                               hdrLen:     60,
+                               csumStart:  40,
+                               csumOffset: 16,
+                       },
+                       tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+                       []int{160, 160}, // two packets of 60 header + 100 payload
+                       false,
+               },
+               {
+                       "udp4",
+                       virtioNetHdr{
+                               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+                               gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+                               gsoSize:    100,
+                               hdrLen:     28,
+                               csumStart:  20,
+                               csumOffset: 6,
+                       },
+                       udp4Packet(ip4PortA, ip4PortB, 200),
+                       []int{128, 128}, // two packets of 28 header + 100 payload
+                       false,
+               },
+               {
+                       "udp6",
+                       virtioNetHdr{
+                               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+                               gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+                               gsoSize:    100,
+                               hdrLen:     48,
+                               csumStart:  40,
+                               csumOffset: 6,
+                       },
+                       udp6Packet(ip6PortA, ip6PortB, 200),
+                       []int{148, 148}, // two packets of 48 header + 100 payload
+                       false,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       out := make([][]byte, conn.IdealBatchSize)
+                       sizes := make([]int, conn.IdealBatchSize)
+                       for i := range out {
+                               out[i] = make([]byte, 65535)
+                       }
+                       tt.hdr.encode(tt.pktIn) // writes the virtio header at the start of pktIn; NOTE(review): the returned error is ignored
+                       n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+                       if err != nil {
+                               if tt.wantErr {
+                                       return
+                               }
+                               t.Fatalf("got err: %v", err)
+                       }
+                       if n != len(tt.wantLens) {
+                               t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+                       }
+                       for i := range tt.wantLens {
+                               if tt.wantLens[i] != sizes[i] {
+                                       t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+                               }
+                       }
+               })
+       }
+}
+
+func flipTCP4Checksum(b []byte) []byte { // inverts the stored TCP checksum of b in place and returns b
+       at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+       b[at] ^= 0xFF // flip every bit of both checksum bytes
+       b[at+1] ^= 0xFF
+       return b
+}
+
+func flipUDP4Checksum(b []byte) []byte { // inverts the stored UDP checksum of b in place and returns b
+       at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+       b[at] ^= 0xFF // flip every bit of both checksum bytes
+       b[at+1] ^= 0xFF
+       return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+       pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+       pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+       pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+       pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+       pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+       pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+       pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+       pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+       pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+       pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+       pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+       pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+       f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) // seed corpus: well-formed TCP and UDP packets over IPv4 and IPv6
+       f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+               pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+               toWrite := make([]int, 0, len(pkts))
+               handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) // returned error is not checked; only the toWrite invariants below are asserted
+               if len(toWrite) > len(pkts) {
+                       t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+               }
+               seenWriteI := make(map[int]bool) // tracks indices already seen so duplicates can be reported
+               for _, writeI := range toWrite {
+                       if writeI < 0 || writeI > len(pkts)-1 {
+                               t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+                       }
+                       if seenWriteI[writeI] {
+                               t.Errorf("duplicate toWrite value: %d", writeI)
+                       }
+                       seenWriteI[writeI] = true
+               }
+       })
+}
+
+func Test_handleGRO(t *testing.T) {
+       tests := []struct {
+               name        string
+               pktsIn      [][]byte
+               canUDPGRO   bool
+               wantToWrite []int
+               wantLens    []int
+               wantErr     bool
+       }{
+               {
+                       "multiple protocols and flows",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // tcp4 flow 1
+                               udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
+                               udp4Packet(ip4PortA, ip4PortC, 100),                         // udp4 flow 2
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+                               tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // tcp6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+                               udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
+                               udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
+                               udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
+                       },
+                       true,
+                       []int{0, 1, 2, 4, 5, 7, 9},
+                       []int{240, 228, 128, 140, 260, 160, 248},
+                       false,
+               },
+               {
+                       "multiple protocols and flows no UDP GRO",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // tcp4 flow 1
+                               udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
+                               udp4Packet(ip4PortA, ip4PortC, 100),                         // udp4 flow 2
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+                               tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // tcp6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+                               udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
+                               udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
+                               udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
+                       },
+                       false,
+                       []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+                       []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+                       false,
+               },
+               {
+                       "PSH interleaved",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),                     // v4 flow 1
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                   // v4 flow 1
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301),                   // v4 flow 1
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),                     // v6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201),                   // v6 flow 1
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301),                   // v6 flow 1
+                       },
+                       true,
+                       []int{0, 2, 4, 6},
+                       []int{240, 240, 260, 260},
+                       false,
+               },
+               {
+                       "coalesceItemInvalidCSum",
+                       [][]byte{
+                               flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),                 // v4 flow 1 seq 101 len 100
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                 // v4 flow 1 seq 201 len 100
+                               flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+                               udp4Packet(ip4PortA, ip4PortB, 100),
+                               udp4Packet(ip4PortA, ip4PortB, 100),
+                       },
+                       true,
+                       []int{0, 1, 3, 4},
+                       []int{140, 240, 128, 228},
+                       false,
+               },
+               {
+                       "out of order",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1 seq 1 len 100
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+                       },
+                       true,
+                       []int{0},
+                       []int{340},
+                       false,
+               },
+               {
+                       "unequal TTL",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.TTL++
+                               }),
+                               udp4Packet(ip4PortA, ip4PortB, 100),
+                               udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+                                       fields.TTL++
+                               }),
+                       },
+                       true,
+                       []int{0, 1, 2, 3},
+                       []int{140, 140, 128, 128},
+                       false,
+               },
+               {
+                       "unequal ToS",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.TOS++
+                               }),
+                               udp4Packet(ip4PortA, ip4PortB, 100),
+                               udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+                                       fields.TOS++
+                               }),
+                       },
+                       true,
+                       []int{0, 1, 2, 3},
+                       []int{140, 140, 128, 128},
+                       false,
+               },
+               {
+                       "unequal flags more fragments set",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.Flags = 1
+                               }),
+                               udp4Packet(ip4PortA, ip4PortB, 100),
+                               udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+                                       fields.Flags = 1
+                               }),
+                       },
+                       true,
+                       []int{0, 1, 2, 3},
+                       []int{140, 140, 128, 128},
+                       false,
+               },
+               {
+                       "unequal flags DF set",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.Flags = 2
+                               }),
+                               udp4Packet(ip4PortA, ip4PortB, 100),
+                               udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+                                       fields.Flags = 2
+                               }),
+                       },
+                       true,
+                       []int{0, 1, 2, 3},
+                       []int{140, 140, 128, 128},
+                       false,
+               },
+               {
+                       "ipv6 unequal hop limit",
+                       [][]byte{
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+                               tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+                                       fields.HopLimit++
+                               }),
+                               udp6Packet(ip6PortA, ip6PortB, 100),
+                               udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+                                       fields.HopLimit++
+                               }),
+                       },
+                       true,
+                       []int{0, 1, 2, 3},
+                       []int{160, 160, 148, 148},
+                       false,
+               },
+               {
+                       "ipv6 unequal traffic class",
+                       [][]byte{
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+                               tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+                                       fields.TrafficClass++
+                               }),
+                               udp6Packet(ip6PortA, ip6PortB, 100),
+                               udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+                                       fields.TrafficClass++
+                               }),
+                       },
+                       true,
+                       []int{0, 1, 2, 3},
+                       []int{160, 160, 148, 148},
+                       false,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       toWrite := make([]int, 0, len(tt.pktsIn))
+                       err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+                       if err != nil {
+                               if tt.wantErr {
+                                       return
+                               }
+                               t.Fatalf("got err: %v", err)
+                       }
+                       if len(toWrite) != len(tt.wantToWrite) {
+                               t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+                       }
+                       for i, pktI := range tt.wantToWrite {
+                               if tt.wantToWrite[i] != toWrite[i] {
+                                       t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+                               }
+                               if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+                                       t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+                               }
+                       }
+               })
+       }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) { // table-driven check of the groCandidateType returned by packetIsGROCandidate
+       tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] // drop the leading virtioNetHdrLen bytes; fixtures below index this slice from byte 0
+       tcp4TooShort := tcp4[:39] // 39-byte prefix of tcp4; the "tcp4 too short" case expects notGROCandidate
+       ip4InvalidHeaderLen := make([]byte, len(tcp4))
+       copy(ip4InvalidHeaderLen, tcp4) // mutate a copy so tcp4 itself stays intact for the other cases
+       ip4InvalidHeaderLen[0] = 0x46 // IPv4 version/IHL byte: version 4, IHL 6 (a 24-byte header)
+       ip4InvalidProtocol := make([]byte, len(tcp4))
+       copy(ip4InvalidProtocol, tcp4)
+       ip4InvalidProtocol[9] = unix.IPPROTO_GRE // byte 9 is the IPv4 protocol field
+
+       tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+       tcp6TooShort := tcp6[:59] // 59-byte prefix of tcp6; the "tcp6 too short" case expects notGROCandidate
+       ip6InvalidProtocol := make([]byte, len(tcp6))
+       copy(ip6InvalidProtocol, tcp6)
+       ip6InvalidProtocol[6] = unix.IPPROTO_GRE // byte 6 is the IPv6 next header field
+
+       udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+       udp4TooShort := udp4[:27] // 27-byte prefix of udp4; the "udp4 too short" case expects notGROCandidate
+
+       udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+       udp6TooShort := udp6[:47] // 47-byte prefix of udp6; the "udp6 too short" case expects notGROCandidate
+
+       tests := []struct {
+               name      string           // subtest name passed to t.Run
+               b         []byte           // packet bytes handed to packetIsGROCandidate
+               canUDPGRO bool             // second argument to packetIsGROCandidate
+               want      groCandidateType // expected return value
+       }{
+               {
+                       "tcp4",
+                       tcp4,
+                       true,
+                       tcp4GROCandidate,
+               },
+               {
+                       "tcp6",
+                       tcp6,
+                       true,
+                       tcp6GROCandidate,
+               },
+               {
+                       "udp4",
+                       udp4,
+                       true,
+                       udp4GROCandidate,
+               },
+               {
+                       "udp4 no support",
+                       udp4,
+                       false, // same udp4 packet as above, but with canUDPGRO off
+                       notGROCandidate,
+               },
+               {
+                       "udp6",
+                       udp6,
+                       true,
+                       udp6GROCandidate,
+               },
+               {
+                       "udp6 no support",
+                       udp6,
+                       false, // same udp6 packet as above, but with canUDPGRO off
+                       notGROCandidate,
+               },
+               {
+                       "udp4 too short",
+                       udp4TooShort,
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "udp6 too short",
+                       udp6TooShort,
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "tcp4 too short",
+                       tcp4TooShort,
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "tcp6 too short",
+                       tcp6TooShort,
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "invalid IP version",
+                       []byte{0x00}, // single byte whose version nibble is 0
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "invalid IP header len",
+                       ip4InvalidHeaderLen,
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "ip4 invalid protocol",
+                       ip4InvalidProtocol,
+                       true,
+                       notGROCandidate,
+               },
+               {
+                       "ip6 invalid protocol",
+                       ip6InvalidProtocol,
+                       true,
+                       notGROCandidate,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { // the only assertion: exact match on the returned candidate type
+                               t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+                       }
+               })
+       }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) { // table-driven check of the canCoalesce value returned by udpPacketsCanCoalesce
+       udp4a := udp4Packet(ip4PortA, ip4PortB, 100) // udp4a and udp4b are built from identical arguments
+       udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+       udp4c := udp4Packet(ip4PortA, ip4PortB, 110) // same endpoints, third argument 110 instead of 100 (presumably the payload length; confirm against udp4Packet)
+
+       type args struct { // one field per udpPacketsCanCoalesce parameter, in call order
+               pkt        []byte     // candidate packet; every case slices it starting at offset
+               iphLen     uint8      // 20 in every case below
+               gsoSize    uint16     // gsoSize argument for the candidate packet
+               item       udpGROItem // only gsoSize and iphLen are set; other fields keep their zero values
+               bufs       [][]byte   // packet buffers handed to the function alongside item
+               bufsOffset int        // always offset in the cases below
+       }
+       tests := []struct {
+               name string      // subtest name passed to t.Run
+               args args        // inputs forwarded to udpPacketsCanCoalesce
+               want canCoalesce // expected return value
+       }{
+               {
+                       "coalesceAppend equal gso",
+                       args{
+                               pkt:     udp4a[offset:],
+                               iphLen:  20,
+                               gsoSize: 100,
+                               item: udpGROItem{
+                                       gsoSize: 100,
+                                       iphLen:  20,
+                               },
+                               bufs: [][]byte{
+                                       udp4a,
+                                       udp4b,
+                               },
+                               bufsOffset: offset,
+                       },
+                       coalesceAppend,
+               },
+               {
+                       "coalesceAppend smaller gso",
+                       args{
+                               pkt:     udp4a[offset : len(udp4a)-90], // udp4a with its last 90 bytes cut off
+                               iphLen:  20,
+                               gsoSize: 10, // smaller than the item's gsoSize of 100
+                               item: udpGROItem{
+                                       gsoSize: 100,
+                                       iphLen:  20,
+                               },
+                               bufs: [][]byte{
+                                       udp4a,
+                                       udp4b,
+                               },
+                               bufsOffset: offset,
+                       },
+                       coalesceAppend,
+               },
+               {
+                       "coalesceUnavailable smaller gso previously appended",
+                       args{
+                               pkt:     udp4a[offset:],
+                               iphLen:  20,
+                               gsoSize: 100,
+                               item: udpGROItem{
+                                       gsoSize: 100,
+                                       iphLen:  20,
+                               },
+                               bufs: [][]byte{
+                                       udp4c, // unlike the cases above, bufs[0] is the packet built with 110
+                                       udp4b,
+                               },
+                               bufsOffset: offset,
+                       },
+                       coalesceUnavailable,
+               },
+               {
+                       "coalesceUnavailable larger following smaller",
+                       args{
+                               pkt:     udp4c[offset:],
+                               iphLen:  20,
+                               gsoSize: 110, // larger than the item's gsoSize of 100
+                               item: udpGROItem{
+                                       gsoSize: 100,
+                                       iphLen:  20,
+                               },
+                               bufs: [][]byte{
+                                       udp4a,
+                                       udp4c,
+                               },
+                               bufsOffset: offset,
+                       },
+                       coalesceUnavailable,
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { // the only assertion: exact match on the returned value
+                               t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+                       }
+               })
+       }
+}
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go
deleted file mode 100644 (file)
index ddddc48..0000000
+++ /dev/null
@@ -1,411 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
-
-package tun
-
-import (
-       "net/netip"
-       "testing"
-
-       "golang.org/x/sys/unix"
-       "golang.zx2c4.com/wireguard/conn"
-       "gvisor.dev/gvisor/pkg/tcpip"
-       "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-const (
-       offset = virtioNetHdrLen // the packet builders below place the IP header at b[offset:]
-)
-
-var (
-       ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") // IPv4 fixtures: three distinct addresses, all on port 1
-       ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
-       ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
-       ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") // IPv6 fixtures: three distinct addresses, all on port 1
-       ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
-       ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
-)
-
-func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { // builds a TCP/IPv4 test packet; a non-nil ipFn may edit the IPv4 fields before they are encoded
-       totalLen := 40 + segmentSize // 20-byte IPv4 header + 20-byte TCP header + segmentSize payload bytes
-       b := make([]byte, offset+int(totalLen), 65535) // the first offset bytes are left zero; capacity 65535 presumably leaves room for later in-place growth (TODO confirm)
-       ipv4H := header.IPv4(b[offset:])
-       srcAs4 := srcIPPort.Addr().As4()
-       dstAs4 := dstIPPort.Addr().As4()
-       ipFields := &header.IPv4Fields{
-               SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
-               DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
-               Protocol:    unix.IPPROTO_TCP,
-               TTL:         64,
-               TotalLength: uint16(totalLen),
-       }
-       if ipFn != nil {
-               ipFn(ipFields) // caller hook: runs before Encode, so its changes land in the header
-       }
-       ipv4H.Encode(ipFields)
-       tcpH := header.TCP(b[offset+20:]) // TCP header starts right after the 20-byte IPv4 header
-       tcpH.Encode(&header.TCPFields{
-               SrcPort:    srcIPPort.Port(),
-               DstPort:    dstIPPort.Port(),
-               SeqNum:     seq,
-               AckNum:     1,
-               DataOffset: 20,
-               Flags:      flags,
-               WindowSize: 3000,
-       })
-       ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) // store the complemented IPv4 header checksum
-       pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) // length argument is TCP header (20) + payload
-       tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) // payload bytes are never written, so they remain zero
-       return b
-}
-
-func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { // convenience wrapper: TCP/IPv4 packet with the default IP fields
-       return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) // nil ipFn means the IPv4 fields are encoded unmodified
-}
-
-func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { // builds a TCP/IPv6 test packet; a non-nil ipFn may edit the IPv6 fields before they are encoded
-       totalLen := 60 + segmentSize // 40-byte IPv6 header + 20-byte TCP header + segmentSize payload bytes
-       b := make([]byte, offset+int(totalLen), 65535) // the first offset bytes are left zero; capacity 65535 presumably leaves room for later in-place growth (TODO confirm)
-       ipv6H := header.IPv6(b[offset:])
-       srcAs16 := srcIPPort.Addr().As16()
-       dstAs16 := dstIPPort.Addr().As16()
-       ipFields := &header.IPv6Fields{
-               SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
-               DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
-               TransportProtocol: unix.IPPROTO_TCP,
-               HopLimit:          64,
-               PayloadLength:     uint16(segmentSize + 20), // TCP header (20) + payload; excludes the IPv6 header
-       }
-       if ipFn != nil {
-               ipFn(ipFields) // caller hook: runs before Encode, so its changes land in the header
-       }
-       ipv6H.Encode(ipFields)
-       tcpH := header.TCP(b[offset+40:]) // TCP header starts right after the 40-byte IPv6 header
-       tcpH.Encode(&header.TCPFields{
-               SrcPort:    srcIPPort.Port(),
-               DstPort:    dstIPPort.Port(),
-               SeqNum:     seq,
-               AckNum:     1,
-               DataOffset: 20,
-               Flags:      flags,
-               WindowSize: 3000,
-       })
-       pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) // unlike the IPv4 builder, no IP header checksum is set; only the TCP checksum
-       tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) // payload bytes are never written, so they remain zero
-       return b
-}
-
-func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { // convenience wrapper: TCP/IPv6 packet with the default IP fields
-       return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) // nil ipFn means the IPv6 fields are encoded unmodified
-}
-
-func Test_handleVirtioRead(t *testing.T) { // table-driven check of the packet count and sizes reported by handleVirtioRead
-       tests := []struct {
-               name     string       // subtest name passed to t.Run
-               hdr      virtioNetHdr // header encoded into pktIn before the call
-               pktIn    []byte       // input buffer handed to handleVirtioRead
-               wantLens []int        // expected leading entries of sizes; its length is the expected packet count
-               wantErr  bool         // when true, a returned error ends the subtest without failure
-       }{
-               {
-                       "tcp4",
-                       virtioNetHdr{
-                               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
-                               gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
-                               gsoSize:    100,
-                               hdrLen:     40,
-                               csumStart:  20,
-                               csumOffset: 16,
-                       },
-                       tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), // 200-byte segment with gsoSize 100
-                       []int{140, 140}, // two packets: 40 header bytes + 100 payload bytes each
-                       false,
-               },
-               {
-                       "tcp6",
-                       virtioNetHdr{
-                               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
-                               gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
-                               gsoSize:    100,
-                               hdrLen:     60,
-                               csumStart:  40,
-                               csumOffset: 16,
-                       },
-                       tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), // 200-byte segment with gsoSize 100
-                       []int{160, 160}, // two packets: 60 header bytes + 100 payload bytes each
-                       false,
-               },
-       }
-
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       out := make([][]byte, conn.IdealBatchSize) // one 65535-byte output buffer per batch slot
-                       sizes := make([]int, conn.IdealBatchSize)
-                       for i := range out {
-                               out[i] = make([]byte, 65535)
-                       }
-                       tt.hdr.encode(tt.pktIn) // any return value from encode is ignored here
-                       n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
-                       if err != nil {
-                               if tt.wantErr { // an error satisfies a wantErr case; nothing further is checked
-                                       return
-                               }
-                               t.Fatalf("got err: %v", err)
-                       }
-                       if n != len(tt.wantLens) {
-                               t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
-                       }
-                       for i := range tt.wantLens {
-                               if tt.wantLens[i] != sizes[i] { // only sizes is compared; the contents of out are not inspected
-                                       t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
-                               }
-                       }
-               })
-       }
-}
-
-func flipTCP4Checksum(b []byte) []byte { // inverts both TCP checksum bytes of a TCP/IPv4 packet in place and returns the same slice
-       at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
-       b[at] ^= 0xFF
-       b[at+1] ^= 0xFF
-       return b
-}
-
-func Fuzz_handleGRO(f *testing.F) { // fuzzes handleGRO with six packets plus an offset and validates the resulting toWrite indices
-       pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
-       pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
-       pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
-       pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
-       pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
-       pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
-       f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) // seed corpus entry: three tcp4 and three tcp6 packets with the package-level offset
-       f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { // these parameters shadow the seeds above and the package-level offset
-               pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
-               toWrite := make([]int, 0, len(pkts))
-               handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) // return value is not inspected; only toWrite is validated below
-               if len(toWrite) > len(pkts) {
-                       t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
-               }
-               seenWriteI := make(map[int]bool) // indices already seen in toWrite, for duplicate detection
-               for _, writeI := range toWrite {
-                       if writeI < 0 || writeI > len(pkts)-1 { // every index must refer to an element of pkts
-                               t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
-                       }
-                       if seenWriteI[writeI] {
-                               t.Errorf("duplicate toWrite value: %d", writeI)
-                       }
-                       seenWriteI[writeI] = true
-               }
-       })
-}
-
-func Test_handleGRO(t *testing.T) {
-       tests := []struct {
-               name        string
-               pktsIn      [][]byte
-               wantToWrite []int
-               wantLens    []int
-               wantErr     bool
-       }{
-               {
-                       "multiple flows",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
-                               tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // v6 flow 1
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
-                               tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
-                       },
-                       []int{0, 2, 3, 5},
-                       []int{240, 140, 260, 160},
-                       false,
-               },
-               {
-                       "PSH interleaved",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),                     // v4 flow 1
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                   // v4 flow 1
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301),                   // v4 flow 1
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),                     // v6 flow 1
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201),                   // v6 flow 1
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301),                   // v6 flow 1
-                       },
-                       []int{0, 2, 4, 6},
-                       []int{240, 240, 260, 260},
-                       false,
-               },
-               {
-                       "coalesceItemInvalidCSum",
-                       [][]byte{
-                               flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),                 // v4 flow 1 seq 101 len 100
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                 // v4 flow 1 seq 201 len 100
-                       },
-                       []int{0, 1},
-                       []int{140, 240},
-                       false,
-               },
-               {
-                       "out of order",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1 seq 1 len 100
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
-                       },
-                       []int{0},
-                       []int{340},
-                       false,
-               },
-               {
-                       "tcp4 unequal TTL",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
-                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
-                                       fields.TTL++
-                               }),
-                       },
-                       []int{0, 1},
-                       []int{140, 140},
-                       false,
-               },
-               {
-                       "tcp4 unequal ToS",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
-                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
-                                       fields.TOS++
-                               }),
-                       },
-                       []int{0, 1},
-                       []int{140, 140},
-                       false,
-               },
-               {
-                       "tcp4 unequal flags more fragments set",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
-                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
-                                       fields.Flags = 1
-                               }),
-                       },
-                       []int{0, 1},
-                       []int{140, 140},
-                       false,
-               },
-               {
-                       "tcp4 unequal flags DF set",
-                       [][]byte{
-                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
-                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
-                                       fields.Flags = 2
-                               }),
-                       },
-                       []int{0, 1},
-                       []int{140, 140},
-                       false,
-               },
-               {
-                       "tcp6 unequal hop limit",
-                       [][]byte{
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
-                               tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
-                                       fields.HopLimit++
-                               }),
-                       },
-                       []int{0, 1},
-                       []int{160, 160},
-                       false,
-               },
-               {
-                       "tcp6 unequal traffic class",
-                       [][]byte{
-                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
-                               tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
-                                       fields.TrafficClass++
-                               }),
-                       },
-                       []int{0, 1},
-                       []int{160, 160},
-                       false,
-               },
-       }
-
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       toWrite := make([]int, 0, len(tt.pktsIn))
-                       err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
-                       if err != nil {
-                               if tt.wantErr {
-                                       return
-                               }
-                               t.Fatalf("got err: %v", err)
-                       }
-                       if len(toWrite) != len(tt.wantToWrite) {
-                               t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
-                       }
-                       for i, pktI := range tt.wantToWrite {
-                               if tt.wantToWrite[i] != toWrite[i] {
-                                       t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
-                               }
-                               if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
-                                       t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
-                               }
-                       }
-               })
-       }
-}
-
-func Test_isTCP4NoIPOptions(t *testing.T) {
-       valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
-       invalidLen := valid[:39]
-       invalidHeaderLen := make([]byte, len(valid))
-       copy(invalidHeaderLen, valid)
-       invalidHeaderLen[0] = 0x46
-       invalidProtocol := make([]byte, len(valid))
-       copy(invalidProtocol, valid)
-       invalidProtocol[9] = unix.IPPROTO_TCP + 1
-
-       tests := []struct {
-               name string
-               b    []byte
-               want bool
-       }{
-               {
-                       "valid",
-                       valid,
-                       true,
-               },
-               {
-                       "invalid length",
-                       invalidLen,
-                       false,
-               },
-               {
-                       "invalid version",
-                       []byte{0x00},
-                       false,
-               },
-               {
-                       "invalid header len",
-                       invalidHeaderLen,
-                       false,
-               },
-               {
-                       "invalid protocol",
-                       invalidProtocol,
-                       false,
-               },
-       }
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       if got := isTCP4NoIPOptions(tt.b); got != tt.want {
-                               t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want)
-                       }
-               })
-       }
-}
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
deleted file mode 100644 (file)
index 5461e79..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-go test fuzz v1
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-int(34)
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
deleted file mode 100644 (file)
index b441819..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-go test fuzz v1
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-int(-48)
index 12cd49f747dc120e84fb208d780da31a39121dfd..bd69cb552c58610c1c6716d349fe3f3e535137f4 100644 (file)
@@ -38,6 +38,7 @@ type NativeTun struct {
        statusListenersShutdown chan struct{}
        batchSize               int
        vnetHdr                 bool
+       udpGSO                  bool
 
        closeOnce sync.Once
 
@@ -48,9 +49,10 @@ type NativeTun struct {
        readOpMu sync.Mutex                    // readOpMu guards readBuff
        readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
 
-       writeOpMu                  sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
-       toWrite                    []int
-       tcp4GROTable, tcp6GROTable *tcpGROTable
+       writeOpMu   sync.Mutex // writeOpMu guards toWrite, tcpGROTable, udpGROTable
+       toWrite     []int
+       tcpGROTable *tcpGROTable
+       udpGROTable *udpGROTable
 }
 
 func (tun *NativeTun) File() *os.File {
@@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) {
 func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
        tun.writeOpMu.Lock()
        defer func() {
-               tun.tcp4GROTable.reset()
-               tun.tcp6GROTable.reset()
+               tun.tcpGROTable.reset()
+               tun.udpGROTable.reset()
                tun.writeOpMu.Unlock()
        }()
        var (
@@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
        )
        tun.toWrite = tun.toWrite[:0]
        if tun.vnetHdr {
-               err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
+               err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
                if err != nil {
                        return 0, err
                }
@@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
                sizes[0] = n
                return 1, nil
        }
-       if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+       if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
                return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
        }
 
        ipVersion := in[0] >> 4
        switch ipVersion {
        case 4:
-               if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+               if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
                        return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
                }
        case 6:
-               if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+               if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
                        return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
                }
        default:
                return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
        }
 
-       if len(in) <= int(hdr.csumStart+12) {
-               return 0, errors.New("packet is too short")
-       }
        // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
        // of the entire first packet when the kernel is handling it as part of a
-       // FORWARD path. Instead, parse the TCP header length and add it onto
+       // FORWARD path. Instead, parse the transport header length and add it onto
        // csumStart, which is synonymous for IP header length.
-       tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
-       if tcpHLen < 20 || tcpHLen > 60 {
-               // A TCP header must be between 20 and 60 bytes in length.
-               return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+       if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+               hdr.hdrLen = hdr.csumStart + 8
+       } else {
+               if len(in) <= int(hdr.csumStart+12) {
+                       return 0, errors.New("packet is too short")
+               }
+
+               tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+               if tcpHLen < 20 || tcpHLen > 60 {
+                       // A TCP header must be between 20 and 60 bytes in length.
+                       return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+               }
+               hdr.hdrLen = hdr.csumStart + tcpHLen
        }
-       hdr.hdrLen = hdr.csumStart + tcpHLen
 
        if len(in) < int(hdr.hdrLen) {
                return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
@@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
                return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
        }
 
-       return tcpTSO(in, hdr, bufs, sizes, offset)
+       return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
 }
 
 func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
@@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int {
 
 const (
        // TODO: support TSO with ECN bits
-       tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+       tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+       tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
 )
 
 func (tun *NativeTun) initFromFlags(name string) error {
@@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error {
                }
                got := ifr.Uint16()
                if got&unix.IFF_VNET_HDR != 0 {
-                       err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
+                       // tunTCPOffloads were added in Linux v2.6. We require their support
+                       // if IFF_VNET_HDR is set.
+                       err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
                        if err != nil {
                                return
                        }
                        tun.vnetHdr = true
                        tun.batchSize = conn.IdealBatchSize
+                       // tunUDPOffloads were added in Linux v6.2. We do not return an
+                       // error if they are unsupported at runtime.
+                       tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
                } else {
                        tun.batchSize = 1
                }
@@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
                events:                  make(chan Event, 5),
                errors:                  make(chan error, 5),
                statusListenersShutdown: make(chan struct{}),
-               tcp4GROTable:            newTCPGROTable(),
-               tcp6GROTable:            newTCPGROTable(),
+               tcpGROTable:             newTCPGROTable(),
+               udpGROTable:             newUDPGROTable(),
                toWrite:                 make([]int, 0, conn.IdealBatchSize),
        }
 
@@ -628,12 +641,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
        }
        file := os.NewFile(uintptr(fd), "/dev/tun")
        tun := &NativeTun{
-               tunFile:      file,
-               events:       make(chan Event, 5),
-               errors:       make(chan error, 5),
-               tcp4GROTable: newTCPGROTable(),
-               tcp6GROTable: newTCPGROTable(),
-               toWrite:      make([]int, 0, conn.IdealBatchSize),
+               tunFile:     file,
+               events:      make(chan Event, 5),
+               errors:      make(chan error, 5),
+               tcpGROTable: newTCPGROTable(),
+               udpGROTable: newUDPGROTable(),
+               toWrite:     make([]int, 0, conn.IdealBatchSize),
        }
        name, err := tun.Name()
        if err != nil {