]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: use add-with-carry in checksumNoFold()
authorTu Dinh Ngoc <dinhngoc.tu@irit.fr>
Thu, 20 Jun 2024 13:28:38 +0000 (13:28 +0000)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 5 May 2025 13:10:08 +0000 (15:10 +0200)
Use parallel summation with native byte order per RFC 1071.
add-with-carry operation is used to add 4 words per operation.  Byteswap
is performed before and after checksumming for compatibility with old
`checksumNoFold()`.  With this we get a 30-80% speedup in `checksum()`
depending on packet sizes.

Add unit tests with comparison to a per-word implementation.

**Intel(R) Xeon(R) Silver 4210R CPU @ 2.40GHz**

| Size | OldTime | NewTime | Speedup  |
|------|---------|---------|----------|
| 64   | 12.64   | 9.183   | 1.376456 |
| 128  | 18.52   | 12.72   | 1.455975 |
| 256  | 31.01   | 18.13   | 1.710425 |
| 512  | 54.46   | 29.03   | 1.87599  |
| 1024 | 102     | 52.2    | 1.954023 |
| 1500 | 146.8   | 81.36   | 1.804326 |
| 2048 | 196.9   | 102.5   | 1.920976 |
| 4096 | 389.8   | 200.8   | 1.941235 |
| 8192 | 767.3   | 413.3   | 1.856521 |
| 9000 | 851.7   | 448.8   | 1.897727 |
| 9001 | 854.8   | 451.9   | 1.891569 |

**AMD EPYC 7352 24-Core Processor**

| Size | OldTime | NewTime | Speedup  |
|------|---------|---------|----------|
| 64   | 9.159   | 6.949   | 1.318031 |
| 128  | 13.59   | 10.59   | 1.283286 |
| 256  | 22.37   | 14.91   | 1.500335 |
| 512  | 41.42   | 24.22   | 1.710157 |
| 1024 | 81.59   | 45.05   | 1.811099 |
| 1500 | 120.4   | 68.35   | 1.761522 |
| 2048 | 162.8   | 90.14   | 1.806079 |
| 4096 | 321.4   | 180.3   | 1.782585 |
| 8192 | 650.4   | 360.8   | 1.802661 |
| 9000 | 706.3   | 398.1   | 1.774177 |
| 9001 | 712.4   | 398.2   | 1.789051 |

Signed-off-by: Tu Dinh Ngoc <dinhngoc.tu@irit.fr>
[Jason: simplified and cleaned up unit tests]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/checksum.go
tun/checksum_test.go

index 29a8fc8fc0fe0d7e9a824bc76e79437f666fa473..b489c56f5eb464a010d91e68f0f90ff55196bf9b 100644 (file)
 package tun
 
-import "encoding/binary"
+import (
+       "encoding/binary"
+       "math/bits"
+)
 
 // TODO: Explore SIMD and/or other assembly optimizations.
-// TODO: Test native endian loads. See RFC 1071 section 2 part B.
 func checksumNoFold(b []byte, initial uint64) uint64 {
-       ac := initial
+       tmp := make([]byte, 8)
+       binary.NativeEndian.PutUint64(tmp, initial)
+       ac := binary.BigEndian.Uint64(tmp)
+       var carry uint64
 
        for len(b) >= 128 {
-               ac += uint64(binary.BigEndian.Uint32(b[:4]))
-               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
-               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
-               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
-               ac += uint64(binary.BigEndian.Uint32(b[16:20]))
-               ac += uint64(binary.BigEndian.Uint32(b[20:24]))
-               ac += uint64(binary.BigEndian.Uint32(b[24:28]))
-               ac += uint64(binary.BigEndian.Uint32(b[28:32]))
-               ac += uint64(binary.BigEndian.Uint32(b[32:36]))
-               ac += uint64(binary.BigEndian.Uint32(b[36:40]))
-               ac += uint64(binary.BigEndian.Uint32(b[40:44]))
-               ac += uint64(binary.BigEndian.Uint32(b[44:48]))
-               ac += uint64(binary.BigEndian.Uint32(b[48:52]))
-               ac += uint64(binary.BigEndian.Uint32(b[52:56]))
-               ac += uint64(binary.BigEndian.Uint32(b[56:60]))
-               ac += uint64(binary.BigEndian.Uint32(b[60:64]))
-               ac += uint64(binary.BigEndian.Uint32(b[64:68]))
-               ac += uint64(binary.BigEndian.Uint32(b[68:72]))
-               ac += uint64(binary.BigEndian.Uint32(b[72:76]))
-               ac += uint64(binary.BigEndian.Uint32(b[76:80]))
-               ac += uint64(binary.BigEndian.Uint32(b[80:84]))
-               ac += uint64(binary.BigEndian.Uint32(b[84:88]))
-               ac += uint64(binary.BigEndian.Uint32(b[88:92]))
-               ac += uint64(binary.BigEndian.Uint32(b[92:96]))
-               ac += uint64(binary.BigEndian.Uint32(b[96:100]))
-               ac += uint64(binary.BigEndian.Uint32(b[100:104]))
-               ac += uint64(binary.BigEndian.Uint32(b[104:108]))
-               ac += uint64(binary.BigEndian.Uint32(b[108:112]))
-               ac += uint64(binary.BigEndian.Uint32(b[112:116]))
-               ac += uint64(binary.BigEndian.Uint32(b[116:120]))
-               ac += uint64(binary.BigEndian.Uint32(b[120:124]))
-               ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
+               ac += carry
                b = b[128:]
        }
        if len(b) >= 64 {
-               ac += uint64(binary.BigEndian.Uint32(b[:4]))
-               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
-               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
-               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
-               ac += uint64(binary.BigEndian.Uint32(b[16:20]))
-               ac += uint64(binary.BigEndian.Uint32(b[20:24]))
-               ac += uint64(binary.BigEndian.Uint32(b[24:28]))
-               ac += uint64(binary.BigEndian.Uint32(b[28:32]))
-               ac += uint64(binary.BigEndian.Uint32(b[32:36]))
-               ac += uint64(binary.BigEndian.Uint32(b[36:40]))
-               ac += uint64(binary.BigEndian.Uint32(b[40:44]))
-               ac += uint64(binary.BigEndian.Uint32(b[44:48]))
-               ac += uint64(binary.BigEndian.Uint32(b[48:52]))
-               ac += uint64(binary.BigEndian.Uint32(b[52:56]))
-               ac += uint64(binary.BigEndian.Uint32(b[56:60]))
-               ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
+               ac += carry
                b = b[64:]
        }
        if len(b) >= 32 {
-               ac += uint64(binary.BigEndian.Uint32(b[:4]))
-               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
-               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
-               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
-               ac += uint64(binary.BigEndian.Uint32(b[16:20]))
-               ac += uint64(binary.BigEndian.Uint32(b[20:24]))
-               ac += uint64(binary.BigEndian.Uint32(b[24:28]))
-               ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+               ac += carry
                b = b[32:]
        }
        if len(b) >= 16 {
-               ac += uint64(binary.BigEndian.Uint32(b[:4]))
-               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
-               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
-               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+               ac += carry
                b = b[16:]
        }
        if len(b) >= 8 {
-               ac += uint64(binary.BigEndian.Uint32(b[:4]))
-               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+               ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+               ac += carry
                b = b[8:]
        }
        if len(b) >= 4 {
-               ac += uint64(binary.BigEndian.Uint32(b))
+               ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
+               ac += carry
                b = b[4:]
        }
        if len(b) >= 2 {
-               ac += uint64(binary.BigEndian.Uint16(b))
+               ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
+               ac += carry
                b = b[2:]
        }
        if len(b) == 1 {
-               ac += uint64(b[0]) << 8
+               tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
+               ac, carry = bits.Add64(ac, uint64(tmp), 0)
+               ac += carry
        }
 
-       return ac
+       binary.NativeEndian.PutUint64(tmp, ac)
+       return binary.BigEndian.Uint64(tmp)
 }
 
 func checksum(b []byte, initial uint64) uint16 {
index c1ccff531d062a7342bca90edb80ff38bbc6d83a..4ea9b8b52dd943e4c83a335fbbcb3ec2266a6b51 100644 (file)
@@ -1,11 +1,74 @@
 package tun
 
 import (
+       "encoding/binary"
        "fmt"
        "math/rand"
        "testing"
+
+       "golang.org/x/sys/unix"
 )
 
+func checksumRef(b []byte, initial uint16) uint16 {
+       ac := uint64(initial)
+
+       for len(b) >= 2 {
+               ac += uint64(binary.BigEndian.Uint16(b))
+               b = b[2:]
+       }
+       if len(b) == 1 {
+               ac += uint64(b[0]) << 8
+       }
+
+       for (ac >> 16) > 0 {
+               ac = (ac >> 16) + (ac & 0xffff)
+       }
+       return uint16(ac)
+}
+
+func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
+       sum := checksumRef(srcAddr, 0)
+       sum = checksumRef(dstAddr, sum)
+       sum = checksumRef([]byte{0, protocol}, sum)
+       tmp := make([]byte, 2)
+       binary.BigEndian.PutUint16(tmp, totalLen)
+       return checksumRef(tmp, sum)
+}
+
+func TestChecksum(t *testing.T) {
+       for length := 0; length <= 9001; length++ {
+               buf := make([]byte, length)
+               rng := rand.New(rand.NewSource(1))
+               rng.Read(buf)
+               csum := checksum(buf, 0x1234)
+               csumRef := checksumRef(buf, 0x1234)
+               if csum != csumRef {
+                       t.Error("Expected checksum", csumRef, "got", csum)
+               }
+       }
+}
+
+func TestPseudoHeaderChecksum(t *testing.T) {
+       for _, addrLen := range []int{4, 16} {
+               for length := 0; length <= 9001; length++ {
+                       srcAddr := make([]byte, addrLen)
+                       dstAddr := make([]byte, addrLen)
+                       buf := make([]byte, length)
+                       rng := rand.New(rand.NewSource(1))
+                       rng.Read(srcAddr)
+                       rng.Read(dstAddr)
+                       rng.Read(buf)
+                       phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
+                       csum := checksum(buf, phSum)
+                       phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
+                       csumRef := checksumRef(buf, phSumRef)
+                       if csum != csumRef {
+                               t.Error("Expected checksumRef", csumRef, "got", csum)
+                       }
+               }
+       }
+}
+
 func BenchmarkChecksum(b *testing.B) {
        lengths := []int{
                64,