git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun: unwind summing loop in checksumNoFold()
author Jordan Whited <jordan@tailscale.com>
Mon, 2 Oct 2023 21:43:56 +0000 (14:43 -0700)
committer Jason A. Donenfeld <Jason@zx2c4.com>
Tue, 10 Oct 2023 13:07:36 +0000 (15:07 +0200)
$ benchstat old.txt new.txt
goos: linux
goarch: amd64
pkg: golang.zx2c4.com/wireguard/tun
cpu: 12th Gen Intel(R) Core(TM) i5-12400
                 │   old.txt    │               new.txt               │
                 │    sec/op    │   sec/op     vs base                │
Checksum/64-12     10.670n ± 2%   4.769n ± 0%  -55.30% (p=0.000 n=10)
Checksum/128-12    19.665n ± 2%   8.032n ± 0%  -59.16% (p=0.000 n=10)
Checksum/256-12     37.68n ± 1%   16.06n ± 0%  -57.37% (p=0.000 n=10)
Checksum/512-12     76.61n ± 3%   32.13n ± 0%  -58.06% (p=0.000 n=10)
Checksum/1024-12   160.55n ± 4%   64.25n ± 0%  -59.98% (p=0.000 n=10)
Checksum/1500-12   231.05n ± 7%   94.12n ± 0%  -59.26% (p=0.000 n=10)
Checksum/2048-12    309.5n ± 3%   128.5n ± 0%  -58.48% (p=0.000 n=10)
Checksum/4096-12    603.8n ± 4%   257.2n ± 0%  -57.41% (p=0.000 n=10)
Checksum/8192-12   1185.0n ± 3%   515.5n ± 0%  -56.50% (p=0.000 n=10)
Checksum/9000-12   1328.5n ± 5%   564.8n ± 0%  -57.49% (p=0.000 n=10)
Checksum/9001-12   1340.5n ± 3%   564.8n ± 0%  -57.87% (p=0.000 n=10)
geomean             185.3n        77.99n       -57.92%

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/checksum.go
tun/checksum_test.go [new file with mode: 0644]

index f4f847164a073db82fee5e83860f03aaac46ec3b..29a8fc8fc0fe0d7e9a824bc76e79437f666fa473 100644 (file)
@@ -3,23 +3,99 @@ package tun
 import "encoding/binary"
 
 // TODO: Explore SIMD and/or other assembly optimizations.
+// TODO: Test native endian loads. See RFC 1071 section 2 part B.
 func checksumNoFold(b []byte, initial uint64) uint64 {
        ac := initial
-       i := 0
-       n := len(b)
-       for n >= 4 {
-               ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
-               n -= 4
-               i += 4
+       // Add b to ac as big-endian 32-bit words, using fixed-size unrolled stages.
+       for len(b) >= 128 { // 32 words per iteration; repeats while 128+ bytes remain
+               ac += uint64(binary.BigEndian.Uint32(b[:4]))
+               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+               ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+               ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+               ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+               ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+               ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+               ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+               ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+               ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+               ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+               ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+               ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+               ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+               ac += uint64(binary.BigEndian.Uint32(b[64:68]))
+               ac += uint64(binary.BigEndian.Uint32(b[68:72]))
+               ac += uint64(binary.BigEndian.Uint32(b[72:76]))
+               ac += uint64(binary.BigEndian.Uint32(b[76:80]))
+               ac += uint64(binary.BigEndian.Uint32(b[80:84]))
+               ac += uint64(binary.BigEndian.Uint32(b[84:88]))
+               ac += uint64(binary.BigEndian.Uint32(b[88:92]))
+               ac += uint64(binary.BigEndian.Uint32(b[92:96]))
+               ac += uint64(binary.BigEndian.Uint32(b[96:100]))
+               ac += uint64(binary.BigEndian.Uint32(b[100:104]))
+               ac += uint64(binary.BigEndian.Uint32(b[104:108]))
+               ac += uint64(binary.BigEndian.Uint32(b[108:112]))
+               ac += uint64(binary.BigEndian.Uint32(b[112:116]))
+               ac += uint64(binary.BigEndian.Uint32(b[116:120]))
+               ac += uint64(binary.BigEndian.Uint32(b[120:124]))
+               ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+               b = b[128:] // drop the bytes just summed
+       }
+       if len(b) >= 64 { // runs at most once: fewer than 128 bytes remain here
+               ac += uint64(binary.BigEndian.Uint32(b[:4]))
+               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+               ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+               ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+               ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+               ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+               ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+               ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+               ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+               ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+               ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+               ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+               ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+               ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+               b = b[64:]
+       }
+       if len(b) >= 32 { // runs at most once: fewer than 64 bytes remain here
+               ac += uint64(binary.BigEndian.Uint32(b[:4]))
+               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+               ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+               ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+               ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+               ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+               b = b[32:]
+       }
+       if len(b) >= 16 { // runs at most once: fewer than 32 bytes remain here
+               ac += uint64(binary.BigEndian.Uint32(b[:4]))
+               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+               ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+               ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+               b = b[16:]
        }
-       for n >= 2 {
-               ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
-               n -= 2
-               i += 2
+       if len(b) >= 8 { // two 32-bit words
+               ac += uint64(binary.BigEndian.Uint32(b[:4]))
+               ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+               b = b[8:]
        }
-       if n == 1 {
-               ac += uint64(b[i]) << 8
+       if len(b) >= 4 { // one last 32-bit word
+               ac += uint64(binary.BigEndian.Uint32(b))
+               b = b[4:]
        }
+       if len(b) >= 2 { // one trailing 16-bit word
+               ac += uint64(binary.BigEndian.Uint16(b))
+               b = b[2:]
+       }
+       if len(b) == 1 { // odd trailing byte
+               ac += uint64(b[0]) << 8 // counted as the high byte of a 16-bit word whose low byte is zero
+       }
+       // ac is returned as accumulated; no reduction to a narrower width happens here.
        return ac
 }
 
diff --git a/tun/checksum_test.go b/tun/checksum_test.go
new file mode 100644 (file)
index 0000000..c1ccff5
--- /dev/null
@@ -0,0 +1,35 @@
+package tun
+
+import (
+       "fmt"
+       "math/rand"
+       "testing"
+)
+
+func BenchmarkChecksum(b *testing.B) { // times checksum over buffers of several lengths
+       lengths := []int{ // buffer sizes in bytes, one sub-benchmark each
+               64,
+               128,
+               256,
+               512,
+               1024,
+               1500,
+               2048,
+               4096,
+               8192,
+               9000,
+               9001,
+       }
+       // Run one sub-benchmark per length, named after the length in decimal.
+       for _, length := range lengths {
+               b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
+                       buf := make([]byte, length)
+                       rng := rand.New(rand.NewSource(1)) // fixed seed: same buffer contents on every run
+                       rng.Read(buf)
+                       b.ResetTimer() // keep buffer allocation and fill out of the measurement
+                       for i := 0; i < b.N; i++ {
+                               checksum(buf, 0) // result discarded; NOTE(review): checksum is defined outside this diff
+                       }
+               })
+       }
+}