]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Better common bits function
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 13:49:20 +0000 (15:49 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 14 May 2018 13:49:20 +0000 (15:49 +0200)
allowedips.go
allowedips_test.go

index df53abf5998e660768875ddcc988b169a6cc270b..e700dc4746f6c893248e110b61bf6c88f995ad06 100644 (file)
@@ -7,8 +7,10 @@ package main
 
 import (
        "errors"
+       "math/bits"
        "net"
        "sync"
+       "unsafe"
 )
 
 type trieEntry struct {
@@ -23,62 +25,48 @@ type trieEntry struct {
        bit_at_shift uint
 }
 
-/* Finds length of matching prefix
- *
- * TODO: Only use during insertion (xor + prefix mask for lookup)
- *       Check out
- *       prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
- *       https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
- *
- * Assumption:
- *       len(ip1) == len(ip2)
- *       len(ip1) mod 4 = 0
- */
-func commonBits(ip1 []byte, ip2 []byte) uint {
-       var i uint
-       size := uint(len(ip1))
-
-       for i = 0; i < size; i++ {
-               v := ip1[i] ^ ip2[i]
-               if v != 0 {
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 7
-                       }
-
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 6
-                       }
-
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 5
-                       }
-
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 4
-                       }
-
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 3
-                       }
-
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 2
-                       }
-
-                       v >>= 1
-                       if v == 0 {
-                               return i*8 + 1
-                       }
-                       return i * 8
+func isLittleEndian() bool {
+       one := uint32(1)
+       return *(*byte)(unsafe.Pointer(&one)) != 0
+}
+
+func swapU32(i uint32) uint32 {
+       if !isLittleEndian() {
+               return i
+       }
+
+       return bits.ReverseBytes32(i)
+}
+
+func swapU64(i uint64) uint64 {
+       if !isLittleEndian() {
+               return i
+       }
+
+       return bits.ReverseBytes64(i)
+}
+
+func commonBits(ip1 net.IP, ip2 net.IP) uint {
+       size := len(ip1)
+       if size == net.IPv4len {
+               a := (*uint32)(unsafe.Pointer(&ip1[0]))
+               b := (*uint32)(unsafe.Pointer(&ip2[0]))
+               x := *a ^ *b
+               return uint(bits.LeadingZeros32(swapU32(x)))
+       } else if size == net.IPv6len {
+               a := (*uint64)(unsafe.Pointer(&ip1[0]))
+               b := (*uint64)(unsafe.Pointer(&ip2[0]))
+               x := *a ^ *b
+               if x != 0 {
+                       return uint(bits.LeadingZeros64(swapU64(x)))
                }
+               a = (*uint64)(unsafe.Pointer(&ip1[8]))
+               b = (*uint64)(unsafe.Pointer(&ip2[8]))
+               x = *a ^ *b
+               return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+       } else {
+               panic("Wrong size bit string")
        }
-       return i * 8
 }
 
 func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
index 7b73af3385912fb53d1bed837b748066f0d74f62..ce21cb4a90b9d3f5c7ea0479b852cd4f57fc4488 100644 (file)
@@ -106,7 +106,7 @@ func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
 }
 
 /* Test ported from kernel implementation:
- * selftest/routingtable.h
+ * selftest/allowedips.h
  */
 func TestTrieIPv4(t *testing.T) {
        a := &Peer{}
@@ -192,7 +192,7 @@ func TestTrieIPv4(t *testing.T) {
 }
 
 /* Test ported from kernel implementation:
- * selftest/routingtable.h
+ * selftest/allowedips.h
  */
 func TestTrieIPv6(t *testing.T) {
        a := &Peer{}