]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: reduce redundant checksumming in tcpGRO()
authorJordan Whited <jordan@tailscale.com>
Mon, 2 Oct 2023 21:46:13 +0000 (14:46 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Oct 2023 13:07:36 +0000 (15:07 +0200)
IPv4 header and pseudo header checksums were being computed on every
merge operation. Additionally, virtioNetHdr was being written at the
same time. This delays those operations until after all coalescing has
occurred.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/tcp_offload_linux.go

index 39a7180c5a2a50f65901e630fb2e12f02e6dc7d6..1afd27edfb331c81e78f228709f0796f067482d9 100644 (file)
@@ -269,11 +269,11 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
 type coalesceResult int
 
 const (
-       coalesceInsufficientCap coalesceResult = 0
-       coalescePSHEnding       coalesceResult = 1
-       coalesceItemInvalidCSum coalesceResult = 2
-       coalescePktInvalidCSum  coalesceResult = 3
-       coalesceSuccess         coalesceResult = 4
+       coalesceInsufficientCap coalesceResult = iota
+       coalescePSHEnding
+       coalesceItemInvalidCSum
+       coalescePktInvalidCSum
+       coalesceSuccess
 )
 
 // coalesceTCPPackets attempts to coalesce pkt with the packet described by
@@ -339,42 +339,6 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
        if gsoSize > item.gsoSize {
                item.gsoSize = gsoSize
        }
-       hdr := virtioNetHdr{
-               flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
-               hdrLen:     uint16(headersLen),
-               gsoSize:    uint16(item.gsoSize),
-               csumStart:  uint16(item.iphLen),
-               csumOffset: 16,
-       }
-
-       // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
-       // (IPv4) header checksum.
-       if isV6 {
-               hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
-               binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
-       } else {
-               hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
-               pktHead[10], pktHead[11] = 0, 0                               // clear checksum field
-               binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
-               iphCSum := ^checksum(pktHead[:item.iphLen], 0)                // compute checksum
-               binary.BigEndian.PutUint16(pktHead[10:], iphCSum)             // set checksum field
-       }
-       hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
-
-       // Calculate the pseudo header checksum and place it at the TCP checksum
-       // offset. Downstream checksum offloading will combine this with computation
-       // of the tcp header and payload checksum.
-       addrLen := 4
-       addrOffset := ipv4SrcAddrOffset
-       if isV6 {
-               addrLen = 16
-               addrOffset = ipv6SrcAddrOffset
-       }
-       srcAddrAt := bufsOffset + addrOffset
-       srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
-       dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
-       psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
-       binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
 
        item.numMerged++
        return coalesceSuccess
@@ -390,43 +354,52 @@ const (
        maxUint16         = 1<<16 - 1
 )
 
+type tcpGROResult int
+
+const (
+       tcpGROResultNoop tcpGROResult = iota
+       tcpGROResultTableInsert
+       tcpGROResultCoalesced
+)
+
 // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
-// existing packets tracked in table. It will return false when pktI is not
-// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
-// should be written to the Device.
-func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
+// existing packets tracked in table. It returns a tcpGROResultNoop when no
+// action was taken, tcpGROResultTableInsert when the evaluated packet was
+// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
        pkt := bufs[pktI][offset:]
        if len(pkt) > maxUint16 {
                // A valid IPv4 or IPv6 packet will never exceed this.
-               return false
+               return tcpGROResultNoop
        }
        iphLen := int((pkt[0] & 0x0F) * 4)
        if isV6 {
                iphLen = 40
                ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
                if ipv6HPayloadLen != len(pkt)-iphLen {
-                       return false
+                       return tcpGROResultNoop
                }
        } else {
                totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
                if totalLen != len(pkt) {
-                       return false
+                       return tcpGROResultNoop
                }
        }
        if len(pkt) < iphLen {
-               return false
+               return tcpGROResultNoop
        }
        tcphLen := int((pkt[iphLen+12] >> 4) * 4)
        if tcphLen < 20 || tcphLen > 60 {
-               return false
+               return tcpGROResultNoop
        }
        if len(pkt) < iphLen+tcphLen {
-               return false
+               return tcpGROResultNoop
        }
        if !isV6 {
                if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
                        // no GRO support for fragmented segments for now
-                       return false
+                       return tcpGROResultNoop
                }
        }
        tcpFlags := pkt[iphLen+tcpFlagsOffset]
@@ -434,14 +407,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
        // not a candidate if any non-ACK flags (except PSH+ACK) are set
        if tcpFlags != tcpFlagACK {
                if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
-                       return false
+                       return tcpGROResultNoop
                }
                pshSet = true
        }
        gsoSize := uint16(len(pkt) - tcphLen - iphLen)
        // not a candidate if payload len is 0
        if gsoSize < 1 {
-               return false
+               return tcpGROResultNoop
        }
        seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
        srcAddrOffset := ipv4SrcAddrOffset
@@ -452,7 +425,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
        }
        items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
        if !existing {
-               return false
+               return tcpGROResultNoop
        }
        for i := len(items) - 1; i >= 0; i-- {
                // In the best case of packets arriving in order iterating in reverse is
@@ -470,20 +443,20 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
                        switch result {
                        case coalesceSuccess:
                                table.updateAt(item, i)
-                               return true
+                               return tcpGROResultCoalesced
                        case coalesceItemInvalidCSum:
                                // delete the item with an invalid csum
                                table.deleteAt(item.key, i)
                        case coalescePktInvalidCSum:
                                // no point in inserting an item that we can't coalesce
-                               return false
+                               return tcpGROResultNoop
                        default:
                        }
                }
        }
        // failed to coalesce with any other packets; store the item in the flow
        table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
-       return false
+       return tcpGROResultTableInsert
 }
 
 func isTCP4NoIPOptions(b []byte) bool {
@@ -515,6 +488,64 @@ func isTCP6NoEH(b []byte) bool {
        return true
 }
 
+// applyCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
+       for _, items := range table.itemsByFlow {
+               for _, item := range items {
+                       if item.numMerged > 0 {
+                               hdr := virtioNetHdr{
+                                       flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+                                       hdrLen:     uint16(item.iphLen + item.tcphLen),
+                                       gsoSize:    item.gsoSize,
+                                       csumStart:  uint16(item.iphLen),
+                                       csumOffset: 16,
+                               }
+                               pkt := bufs[item.bufsIndex][offset:]
+
+                               // Recalculate the total len (IPv4) or payload len (IPv6).
+                               // Recalculate the (IPv4) header checksum.
+                               if isV6 {
+                                       hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+                                       binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+                               } else {
+                                       hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+                                       pkt[10], pkt[11] = 0, 0
+                                       binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+                                       iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
+                                       binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
+                               }
+                               err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+                               if err != nil {
+                                       return err
+                               }
+
+                               // Calculate the pseudo header checksum and place it at the TCP
+                               // checksum offset. Downstream checksum offloading will combine
+                               // this with computation of the tcp header and payload checksum.
+                               addrLen := 4
+                               addrOffset := ipv4SrcAddrOffset
+                               if isV6 {
+                                       addrLen = 16
+                                       addrOffset = ipv6SrcAddrOffset
+                               }
+                               srcAddrAt := offset + addrOffset
+                               srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+                               dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+                               psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+                               binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+                       } else {
+                               hdr := virtioNetHdr{}
+                               err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+                               if err != nil {
+                                       return err
+                               }
+                       }
+               }
+       }
+       return nil
+}
+
 // handleGRO evaluates bufs for GRO, and writes the indices of the resulting
 // packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
 // empty (but non-nil), and are passed in to save allocs as the caller may reset
@@ -524,23 +555,28 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW
                if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
                        return errors.New("invalid offset")
                }
-               var coalesced bool
+               var result tcpGROResult
                switch {
                case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
-                       coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
+                       result = tcpGRO(bufs, offset, i, tcp4Table, false)
                case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
-                       coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
+                       result = tcpGRO(bufs, offset, i, tcp6Table, true)
                }
-               if !coalesced {
+               switch result {
+               case tcpGROResultNoop:
                        hdr := virtioNetHdr{}
                        err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
                        if err != nil {
                                return err
                        }
+                       fallthrough
+               case tcpGROResultTableInsert:
                        *toWrite = append(*toWrite, i)
                }
        }
-       return nil
+       err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
+       err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
+       return errors.Join(err4, err6)
 }
 
 // tcpTSO splits packets from in into outBuffs, writing the size of each