]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: use correct IP header comparisons in tcpGRO() and tcpPacketsCanCoalesce()
authorJordan Whited <jordan@tailscale.com>
Fri, 24 Mar 2023 23:23:42 +0000 (16:23 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sat, 25 Mar 2023 22:13:38 +0000 (23:13 +0100)
tcpGRO() was using an incorrect IPv4 more fragments bit mask.

tcpPacketsCanCoalesce() was not distinguishing tcp6 from tcp4, and TTL
values were not compared. TTL values should be equal at the IP layer,
otherwise the packets should not coalesce. This tracks with the kernel.

Reviewed-by: Denton Gentry <dgentry@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/tcp_offload_linux.go
tun/tcp_offload_linux_test.go

index 4912efd3f75acf08af8ee6953ea8434a83d9c705..39a7180c5a2a50f65901e630fb2e12f02e6dc7d6 100644 (file)
@@ -189,14 +189,29 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
                        return coalesceUnavailable
                }
        }
-       if pkt[1] != pktTarget[1] {
-               // cannot coalesce with unequal ToS values
-               return coalesceUnavailable
-       }
-       if pkt[6]>>5 != pktTarget[6]>>5 {
-               // cannot coalesce with unequal DF or reserved bits. MF is checked
-               // further up the stack.
-               return coalesceUnavailable
+       if pkt[0]>>4 == 6 {
+               if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
+                       // cannot coalesce with unequal Traffic class values
+                       return coalesceUnavailable
+               }
+               if pkt[7] != pktTarget[7] {
+                       // cannot coalesce with unequal Hop limit values
+                       return coalesceUnavailable
+               }
+       } else {
+               if pkt[1] != pktTarget[1] {
+                       // cannot coalesce with unequal ToS values
+                       return coalesceUnavailable
+               }
+               if pkt[6]>>5 != pktTarget[6]>>5 {
+                       // cannot coalesce with unequal DF or reserved bits. MF is checked
+                       // further up the stack.
+                       return coalesceUnavailable
+               }
+               if pkt[8] != pktTarget[8] {
+                       // cannot coalesce with unequal TTL values
+                       return coalesceUnavailable
+               }
        }
        // seq adjacency
        lhsLen := item.gsoSize
@@ -366,7 +381,7 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
 }
 
 const (
-       ipv4FlagMoreFragments = 0x80
+       ipv4FlagMoreFragments uint8 = 0x20
 )
 
 const (
@@ -409,7 +424,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
                return false
        }
        if !isV6 {
-               if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) {
+               if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
                        // no GRO support for fragmented segments for now
                        return false
                }
index 046e177e04e4aa0c6c07fe93f20d7dc3bee17fc9..9160e18cdd941ef0ed88b47863d8e6ef92c5a88c 100644 (file)
@@ -28,19 +28,23 @@ var (
        ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
 )
 
-func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
        totalLen := 40 + segmentSize
        b := make([]byte, offset+int(totalLen), 65535)
        ipv4H := header.IPv4(b[offset:])
        srcAs4 := srcIPPort.Addr().As4()
        dstAs4 := dstIPPort.Addr().As4()
-       ipv4H.Encode(&header.IPv4Fields{
+       ipFields := &header.IPv4Fields{
                SrcAddr:     tcpip.Address(srcAs4[:]),
                DstAddr:     tcpip.Address(dstAs4[:]),
                Protocol:    unix.IPPROTO_TCP,
                TTL:         64,
                TotalLength: uint16(totalLen),
-       })
+       }
+       if ipFn != nil {
+               ipFn(ipFields)
+       }
+       ipv4H.Encode(ipFields)
        tcpH := header.TCP(b[offset+20:])
        tcpH.Encode(&header.TCPFields{
                SrcPort:    srcIPPort.Port(),
@@ -57,19 +61,27 @@ func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segm
        return b
 }
 
-func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+       return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
        totalLen := 60 + segmentSize
        b := make([]byte, offset+int(totalLen), 65535)
        ipv6H := header.IPv6(b[offset:])
        srcAs16 := srcIPPort.Addr().As16()
        dstAs16 := dstIPPort.Addr().As16()
-       ipv6H.Encode(&header.IPv6Fields{
+       ipFields := &header.IPv6Fields{
                SrcAddr:           tcpip.Address(srcAs16[:]),
                DstAddr:           tcpip.Address(dstAs16[:]),
                TransportProtocol: unix.IPPROTO_TCP,
                HopLimit:          64,
                PayloadLength:     uint16(segmentSize + 20),
-       })
+       }
+       if ipFn != nil {
+               ipFn(ipFields)
+       }
+       ipv6H.Encode(ipFields)
        tcpH := header.TCP(b[offset+40:])
        tcpH.Encode(&header.TCPFields{
                SrcPort:    srcIPPort.Port(),
@@ -85,6 +97,10 @@ func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segm
        return b
 }
 
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+       return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
 func Test_handleVirtioRead(t *testing.T) {
        tests := []struct {
                name     string
@@ -245,6 +261,78 @@ func Test_handleGRO(t *testing.T) {
                        []int{340},
                        false,
                },
+               {
+                       "tcp4 unequal TTL",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.TTL++
+                               }),
+                       },
+                       []int{0, 1},
+                       []int{140, 140},
+                       false,
+               },
+               {
+                       "tcp4 unequal ToS",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.TOS++
+                               }),
+                       },
+                       []int{0, 1},
+                       []int{140, 140},
+                       false,
+               },
+               {
+                       "tcp4 unequal flags more fragments set",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.Flags = 1
+                               }),
+                       },
+                       []int{0, 1},
+                       []int{140, 140},
+                       false,
+               },
+               {
+                       "tcp4 unequal flags DF set",
+                       [][]byte{
+                               tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+                               tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+                                       fields.Flags = 2
+                               }),
+                       },
+                       []int{0, 1},
+                       []int{140, 140},
+                       false,
+               },
+               {
+                       "tcp6 unequal hop limit",
+                       [][]byte{
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+                               tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+                                       fields.HopLimit++
+                               }),
+                       },
+                       []int{0, 1},
+                       []int{160, 160},
+                       false,
+               },
+               {
+                       "tcp6 unequal traffic class",
+                       [][]byte{
+                               tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+                               tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+                                       fields.TrafficClass++
+                               }),
+                       },
+                       []int{0, 1},
+                       []int{160, 160},
+                       false,
+               },
        }
 
        for _, tt := range tests {