package tun
-import "encoding/binary"
+import (
+ "encoding/binary"
+ "math/bits"
+)
// TODO: Explore SIMD and/or other assembly optimizations.
-// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
- ac := initial
+ tmp := make([]byte, 8)
+ binary.NativeEndian.PutUint64(tmp, initial)
+ ac := binary.BigEndian.Uint64(tmp)
+ var carry uint64
for len(b) >= 128 {
- ac += uint64(binary.BigEndian.Uint32(b[:4]))
- ac += uint64(binary.BigEndian.Uint32(b[4:8]))
- ac += uint64(binary.BigEndian.Uint32(b[8:12]))
- ac += uint64(binary.BigEndian.Uint32(b[12:16]))
- ac += uint64(binary.BigEndian.Uint32(b[16:20]))
- ac += uint64(binary.BigEndian.Uint32(b[20:24]))
- ac += uint64(binary.BigEndian.Uint32(b[24:28]))
- ac += uint64(binary.BigEndian.Uint32(b[28:32]))
- ac += uint64(binary.BigEndian.Uint32(b[32:36]))
- ac += uint64(binary.BigEndian.Uint32(b[36:40]))
- ac += uint64(binary.BigEndian.Uint32(b[40:44]))
- ac += uint64(binary.BigEndian.Uint32(b[44:48]))
- ac += uint64(binary.BigEndian.Uint32(b[48:52]))
- ac += uint64(binary.BigEndian.Uint32(b[52:56]))
- ac += uint64(binary.BigEndian.Uint32(b[56:60]))
- ac += uint64(binary.BigEndian.Uint32(b[60:64]))
- ac += uint64(binary.BigEndian.Uint32(b[64:68]))
- ac += uint64(binary.BigEndian.Uint32(b[68:72]))
- ac += uint64(binary.BigEndian.Uint32(b[72:76]))
- ac += uint64(binary.BigEndian.Uint32(b[76:80]))
- ac += uint64(binary.BigEndian.Uint32(b[80:84]))
- ac += uint64(binary.BigEndian.Uint32(b[84:88]))
- ac += uint64(binary.BigEndian.Uint32(b[88:92]))
- ac += uint64(binary.BigEndian.Uint32(b[92:96]))
- ac += uint64(binary.BigEndian.Uint32(b[96:100]))
- ac += uint64(binary.BigEndian.Uint32(b[100:104]))
- ac += uint64(binary.BigEndian.Uint32(b[104:108]))
- ac += uint64(binary.BigEndian.Uint32(b[108:112]))
- ac += uint64(binary.BigEndian.Uint32(b[112:116]))
- ac += uint64(binary.BigEndian.Uint32(b[116:120]))
- ac += uint64(binary.BigEndian.Uint32(b[120:124]))
- ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
+ ac += carry
b = b[128:]
}
if len(b) >= 64 {
- ac += uint64(binary.BigEndian.Uint32(b[:4]))
- ac += uint64(binary.BigEndian.Uint32(b[4:8]))
- ac += uint64(binary.BigEndian.Uint32(b[8:12]))
- ac += uint64(binary.BigEndian.Uint32(b[12:16]))
- ac += uint64(binary.BigEndian.Uint32(b[16:20]))
- ac += uint64(binary.BigEndian.Uint32(b[20:24]))
- ac += uint64(binary.BigEndian.Uint32(b[24:28]))
- ac += uint64(binary.BigEndian.Uint32(b[28:32]))
- ac += uint64(binary.BigEndian.Uint32(b[32:36]))
- ac += uint64(binary.BigEndian.Uint32(b[36:40]))
- ac += uint64(binary.BigEndian.Uint32(b[40:44]))
- ac += uint64(binary.BigEndian.Uint32(b[44:48]))
- ac += uint64(binary.BigEndian.Uint32(b[48:52]))
- ac += uint64(binary.BigEndian.Uint32(b[52:56]))
- ac += uint64(binary.BigEndian.Uint32(b[56:60]))
- ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
+ ac += carry
b = b[64:]
}
if len(b) >= 32 {
- ac += uint64(binary.BigEndian.Uint32(b[:4]))
- ac += uint64(binary.BigEndian.Uint32(b[4:8]))
- ac += uint64(binary.BigEndian.Uint32(b[8:12]))
- ac += uint64(binary.BigEndian.Uint32(b[12:16]))
- ac += uint64(binary.BigEndian.Uint32(b[16:20]))
- ac += uint64(binary.BigEndian.Uint32(b[20:24]))
- ac += uint64(binary.BigEndian.Uint32(b[24:28]))
- ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
+ ac += carry
b = b[32:]
}
if len(b) >= 16 {
- ac += uint64(binary.BigEndian.Uint32(b[:4]))
- ac += uint64(binary.BigEndian.Uint32(b[4:8]))
- ac += uint64(binary.BigEndian.Uint32(b[8:12]))
- ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
+ ac += carry
b = b[16:]
}
if len(b) >= 8 {
- ac += uint64(binary.BigEndian.Uint32(b[:4]))
- ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
+ ac += carry
b = b[8:]
}
if len(b) >= 4 {
- ac += uint64(binary.BigEndian.Uint32(b))
+ ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
+ ac += carry
b = b[4:]
}
if len(b) >= 2 {
- ac += uint64(binary.BigEndian.Uint16(b))
+ ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
+ ac += carry
b = b[2:]
}
if len(b) == 1 {
- ac += uint64(b[0]) << 8
+ tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
+ ac, carry = bits.Add64(ac, uint64(tmp), 0)
+ ac += carry
}
- return ac
+ binary.NativeEndian.PutUint64(tmp, ac)
+ return binary.BigEndian.Uint64(tmp)
}
func checksum(b []byte, initial uint64) uint16 {
package tun
import (
+ "encoding/binary"
"fmt"
"math/rand"
"testing"
+
+ "golang.org/x/sys/unix"
)
+func checksumRef(b []byte, initial uint16) uint16 {
+ ac := uint64(initial)
+
+ for len(b) >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b))
+ b = b[2:]
+ }
+ if len(b) == 1 {
+ ac += uint64(b[0]) << 8
+ }
+
+ for (ac >> 16) > 0 {
+ ac = (ac >> 16) + (ac & 0xffff)
+ }
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
+ sum := checksumRef(srcAddr, 0)
+ sum = checksumRef(dstAddr, sum)
+ sum = checksumRef([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumRef(tmp, sum)
+}
+
+func TestChecksum(t *testing.T) {
+ for length := 0; length <= 9001; length++ {
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(buf)
+ csum := checksum(buf, 0x1234)
+ csumRef := checksumRef(buf, 0x1234)
+ if csum != csumRef {
+ t.Error("Expected checksum", csumRef, "got", csum)
+ }
+ }
+}
+
+func TestPseudoHeaderChecksum(t *testing.T) {
+ for _, addrLen := range []int{4, 16} {
+ for length := 0; length <= 9001; length++ {
+ srcAddr := make([]byte, addrLen)
+ dstAddr := make([]byte, addrLen)
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(srcAddr)
+ rng.Read(dstAddr)
+ rng.Read(buf)
+ phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
+ csum := checksum(buf, phSum)
+ phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
+ csumRef := checksumRef(buf, phSumRef)
+ if csum != csumRef {
+ t.Error("Expected checksumRef", csumRef, "got", csum)
+ }
+ }
+ }
+}
+
func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,