]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: make allowedips generic jd/generic-aip
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 17 Mar 2022 01:34:42 +0000 (19:34 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 17 Mar 2022 01:45:10 +0000 (19:45 -0600)
The implementation of commonBits uses a horrific unsafe.Slice trick.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_rand_test.go
device/allowedips_test.go

index 3cac694fb488d525b4a597d994b3b09811bd0647..c36ef3a57e76dbaf03768d56efe4938f06ae1011 100644 (file)
@@ -16,68 +16,86 @@ import (
        "unsafe"
 )
 
-type parentIndirection struct {
-       parentBit     **trieEntry
+type ipArray interface {
+       [4]byte | [16]byte
+}
+
+type parentIndirection[B ipArray] struct {
+       parentBit     **trieEntry[B]
        parentBitType uint8
 }
 
-type trieEntry struct {
+type trieEntry[B ipArray] struct {
        peer        *Peer
-       child       [2]*trieEntry
-       parent      parentIndirection
+       child       [2]*trieEntry[B]
+       parent      parentIndirection[B]
        cidr        uint8
        bitAtByte   uint8
        bitAtShift  uint8
-       bits        []byte
+       bits        B
        perPeerElem *list.Element
 }
 
-func commonBits(ip1, ip2 []byte) uint8 {
-       size := len(ip1)
-       if size == net.IPv4len {
-               a := binary.BigEndian.Uint32(ip1)
-               b := binary.BigEndian.Uint32(ip2)
-               x := a ^ b
-               return uint8(bits.LeadingZeros32(x))
-       } else if size == net.IPv6len {
-               a := binary.BigEndian.Uint64(ip1)
-               b := binary.BigEndian.Uint64(ip2)
-               x := a ^ b
-               if x != 0 {
-                       return uint8(bits.LeadingZeros64(x))
-               }
-               a = binary.BigEndian.Uint64(ip1[8:])
-               b = binary.BigEndian.Uint64(ip2[8:])
-               x = a ^ b
-               return 64 + uint8(bits.LeadingZeros64(x))
-       } else {
-               panic("Wrong size bit string")
+func commonBits4(ip1, ip2 [4]byte) uint8 {
+       a := binary.BigEndian.Uint32(ip1[:])
+       b := binary.BigEndian.Uint32(ip2[:])
+       x := a ^ b
+       return uint8(bits.LeadingZeros32(x))
+}
+
+func commonBits16(ip1, ip2 [16]byte) uint8 {
+       a := binary.BigEndian.Uint64(ip1[:8])
+       b := binary.BigEndian.Uint64(ip2[:8])
+       x := a ^ b
+       if x != 0 {
+               return uint8(bits.LeadingZeros64(x))
+       }
+       a = binary.BigEndian.Uint64(ip1[8:])
+       b = binary.BigEndian.Uint64(ip2[8:])
+       x = a ^ b
+       return 64 + uint8(bits.LeadingZeros64(x))
+}
+
+func giveMeA4[B ipArray](b B) [4]byte {
+       return *(*[4]byte)(unsafe.Slice(&b[0], 4))
+}
+
+func giveMeA16[B ipArray](b B) [16]byte {
+       return *(*[16]byte)(unsafe.Slice(&b[0], 16))
+}
+
+func commonBits[B ipArray](ip1, ip2 B) uint8 {
+       if len(ip1) == 4 {
+               return commonBits4(giveMeA4(ip1), giveMeA4(ip2))
+       } else if len(ip1) == 16 {
+               return commonBits16(giveMeA16(ip1), giveMeA16(ip2))
        }
+       panic("Wrong size bit string")
 }
 
-func (node *trieEntry) addToPeerEntries() {
+func (node *trieEntry[B]) addToPeerEntries() {
        node.perPeerElem = node.peer.trieEntries.PushBack(node)
 }
 
-func (node *trieEntry) removeFromPeerEntries() {
+func (node *trieEntry[B]) removeFromPeerEntries() {
        if node.perPeerElem != nil {
                node.peer.trieEntries.Remove(node.perPeerElem)
                node.perPeerElem = nil
        }
 }
 
-func (node *trieEntry) choose(ip []byte) byte {
+func (node *trieEntry[B]) choose(ip B) byte {
        return (ip[node.bitAtByte] >> node.bitAtShift) & 1
 }
 
-func (node *trieEntry) maskSelf() {
+func (node *trieEntry[B]) maskSelf() {
        mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
        for i := 0; i < len(mask); i++ {
                node.bits[i] &= mask[i]
        }
 }
 
-func (node *trieEntry) zeroizePointers() {
+func (node *trieEntry[B]) zeroizePointers() {
        // Make the garbage collector's life slightly easier
        node.peer = nil
        node.child[0] = nil
@@ -85,7 +103,7 @@ func (node *trieEntry) zeroizePointers() {
        node.parent.parentBit = nil
 }
 
-func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry[B]) nodePlacement(ip B, cidr uint8) (parent *trieEntry[B], exact bool) {
        for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
                parent = node
                if parent.cidr == cidr {
@@ -98,9 +116,9 @@ func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry,
        return
 }
 
-func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
+func (trie parentIndirection[B]) insert(ip B, cidr uint8, peer *Peer) {
        if *trie.parentBit == nil {
-               node := &trieEntry{
+               node := &trieEntry[B]{
                        peer:       peer,
                        parent:     trie,
                        bits:       ip,
@@ -121,7 +139,7 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
                return
        }
 
-       newNode := &trieEntry{
+       newNode := &trieEntry[B]{
                peer:       peer,
                bits:       ip,
                cidr:       cidr,
@@ -131,14 +149,14 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
        newNode.maskSelf()
        newNode.addToPeerEntries()
 
-       var down *trieEntry
+       var down *trieEntry[B]
        if node == nil {
                down = *trie.parentBit
        } else {
                bit := node.choose(ip)
                down = node.child[bit]
                if down == nil {
-                       newNode.parent = parentIndirection{&node.child[bit], bit}
+                       newNode.parent = parentIndirection[B]{&node.child[bit], bit}
                        node.child[bit] = newNode
                        return
                }
@@ -151,21 +169,21 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
 
        if newNode.cidr == cidr {
                bit := newNode.choose(down.bits)
-               down.parent = parentIndirection{&newNode.child[bit], bit}
+               down.parent = parentIndirection[B]{&newNode.child[bit], bit}
                newNode.child[bit] = down
                if parent == nil {
                        newNode.parent = trie
                        *trie.parentBit = newNode
                } else {
                        bit := parent.choose(newNode.bits)
-                       newNode.parent = parentIndirection{&parent.child[bit], bit}
+                       newNode.parent = parentIndirection[B]{&parent.child[bit], bit}
                        parent.child[bit] = newNode
                }
                return
        }
 
-       node = &trieEntry{
-               bits:       append([]byte{}, newNode.bits...),
+       node = &trieEntry[B]{
+               bits:       newNode.bits,
                cidr:       cidr,
                bitAtByte:  cidr / 8,
                bitAtShift: 7 - (cidr % 8),
@@ -173,22 +191,22 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
        node.maskSelf()
 
        bit := node.choose(down.bits)
-       down.parent = parentIndirection{&node.child[bit], bit}
+       down.parent = parentIndirection[B]{&node.child[bit], bit}
        node.child[bit] = down
        bit = node.choose(newNode.bits)
-       newNode.parent = parentIndirection{&node.child[bit], bit}
+       newNode.parent = parentIndirection[B]{&node.child[bit], bit}
        node.child[bit] = newNode
        if parent == nil {
                node.parent = trie
                *trie.parentBit = node
        } else {
                bit := parent.choose(node.bits)
-               node.parent = parentIndirection{&parent.child[bit], bit}
+               node.parent = parentIndirection[B]{&parent.child[bit], bit}
                parent.child[bit] = node
        }
 }
 
-func (node *trieEntry) lookup(ip []byte) *Peer {
+func (node *trieEntry[B]) lookup(ip B) *Peer {
        var found *Peer
        size := uint8(len(ip))
        for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -205,8 +223,8 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
 }
 
 type AllowedIPs struct {
-       IPv4  *trieEntry
-       IPv6  *trieEntry
+       IPv4  *trieEntry[[4]byte]
+       IPv6  *trieEntry[[16]byte]
        mutex sync.RWMutex
 }
 
@@ -215,14 +233,51 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
        defer table.mutex.RUnlock()
 
        for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
-               node := elem.Value.(*trieEntry)
-               a, _ := netip.AddrFromSlice(node.bits)
-               if !cb(netip.PrefixFrom(a, int(node.cidr))) {
-                       return
+               if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
+                       if !cb(netip.PrefixFrom(netip.AddrFrom4(node.bits), int(node.cidr))) {
+                               return
+                       }
+               } else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
+                       if !cb(netip.PrefixFrom(netip.AddrFrom16(node.bits), int(node.cidr))) {
+                               return
+                       }
                }
        }
 }
 
+func (node *trieEntry[B]) remove() {
+       node.removeFromPeerEntries()
+       node.peer = nil
+       if node.child[0] != nil && node.child[1] != nil {
+               return
+       }
+       bit := 0
+       if node.child[0] == nil {
+               bit = 1
+       }
+       child := node.child[bit]
+       if child != nil {
+               child.parent = node.parent
+       }
+       *node.parent.parentBit = child
+       if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+               node.zeroizePointers()
+               return
+       }
+       parent := (*trieEntry[B])(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+       if parent.peer != nil {
+               node.zeroizePointers()
+               return
+       }
+       child = parent.child[node.parent.parentBitType^1]
+       if child != nil {
+               child.parent = parent.parent
+       }
+       *parent.parent.parentBit = child
+       node.zeroizePointers()
+       parent.zeroizePointers()
+}
+
 func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
@@ -230,38 +285,11 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        var next *list.Element
        for elem := peer.trieEntries.Front(); elem != nil; elem = next {
                next = elem.Next()
-               node := elem.Value.(*trieEntry)
-
-               node.removeFromPeerEntries()
-               node.peer = nil
-               if node.child[0] != nil && node.child[1] != nil {
-                       continue
-               }
-               bit := 0
-               if node.child[0] == nil {
-                       bit = 1
-               }
-               child := node.child[bit]
-               if child != nil {
-                       child.parent = node.parent
+               if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
+                       node.remove()
+               } else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
+                       node.remove()
                }
-               *node.parent.parentBit = child
-               if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
-                       node.zeroizePointers()
-                       continue
-               }
-               parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
-               if parent.peer != nil {
-                       node.zeroizePointers()
-                       continue
-               }
-               child = parent.child[node.parent.parentBitType^1]
-               if child != nil {
-                       child.parent = parent.parent
-               }
-               *parent.parent.parentBit = child
-               node.zeroizePointers()
-               parent.zeroizePointers()
        }
 }
 
@@ -270,11 +298,9 @@ func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
        defer table.mutex.Unlock()
 
        if prefix.Addr().Is6() {
-               ip := prefix.Addr().As16()
-               parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+               parentIndirection[[16]byte]{&table.IPv6, 2}.insert(prefix.Addr().As16(), uint8(prefix.Bits()), peer)
        } else if prefix.Addr().Is4() {
-               ip := prefix.Addr().As4()
-               parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+               parentIndirection[[4]byte]{&table.IPv4, 2}.insert(prefix.Addr().As4(), uint8(prefix.Bits()), peer)
        } else {
                panic(errors.New("inserting unknown address type"))
        }
@@ -285,9 +311,9 @@ func (table *AllowedIPs) Lookup(ip []byte) *Peer {
        defer table.mutex.RUnlock()
        switch len(ip) {
        case net.IPv6len:
-               return table.IPv6.lookup(ip)
+               return table.IPv6.lookup(*(*[16]byte)(ip))
        case net.IPv4len:
-               return table.IPv4.lookup(ip)
+               return table.IPv4.lookup(*(*[4]byte)(ip))
        default:
                panic(errors.New("looking up unknown address type"))
        }
index 0d3eecb067c188f85661cf62558557f386008c79..8c17d025d5ae4992726b8ee373cd399f23351809 100644 (file)
@@ -40,9 +40,18 @@ func (r SlowRouter) Swap(i, j int) {
        r[i], r[j] = r[j], r[i]
 }
 
+func commonBitsSlice(addr1, addr2 []byte) uint8 {
+       if len(addr1) == 4 {
+               return commonBits4(*(*[4]byte)(addr1), *(*[4]byte)(addr2))
+       } else if len(addr1) == 16 {
+               return commonBits16(*(*[16]byte)(addr1), *(*[16]byte)(addr2))
+       }
+       return 0
+}
+
 func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
        for _, t := range r {
-               if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
+               if t.cidr == cidr && commonBitsSlice(t.bits, addr) >= cidr {
                        t.peer = peer
                        t.bits = addr
                        return r
@@ -59,7 +68,7 @@ func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
 
 func (r SlowRouter) Lookup(addr []byte) *Peer {
        for _, t := range r {
-               common := commonBits(t.bits, addr)
+               common := commonBitsSlice(t.bits, addr)
                if common >= t.cidr {
                        return t.peer
                }
index 225c788601d3f8f33fadad87c624805833a71cb8..a0d286fa34c83a4188db4032fe68b29f95e65530 100644 (file)
@@ -7,28 +7,28 @@ package device
 
 import (
        "math/rand"
-       "net"
        "net/netip"
        "testing"
+       "unsafe"
 )
 
-type testPairCommonBits struct {
-       s1    []byte
-       s2    []byte
+type testPairCommonBits4 struct {
+       s1    [4]byte
+       s2    [4]byte
        match uint8
 }
 
-func TestCommonBits(t *testing.T) {
-       tests := []testPairCommonBits{
-               {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
-               {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
-               {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
-               {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
-               {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
+func TestCommonBits4(t *testing.T) {
+       tests := []testPairCommonBits4{
+               {s1: [4]byte{1, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 7},
+               {s1: [4]byte{0, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 13},
+               {s1: [4]byte{0, 4, 53, 253}, s2: [4]byte{0, 4, 53, 252}, match: 31},
+               {s1: [4]byte{192, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 15},
+               {s1: [4]byte{65, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 0},
        }
 
        for _, p := range tests {
-               v := commonBits(p.s1, p.s2)
+               v := commonBits4(p.s1, p.s2)
                if v != p.match {
                        t.Error(
                                "For slice", p.s1, p.s2,
@@ -39,48 +39,46 @@ func TestCommonBits(t *testing.T) {
        }
 }
 
-func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
-       var trie *trieEntry
+func benchmarkTrie[B ipArray](peerNumber, addressNumber int, b *testing.B) {
+       var trie *trieEntry[B]
        var peers []*Peer
-       root := parentIndirection{&trie, 2}
+       root := parentIndirection[B]{&trie, 2}
 
        rand.Seed(1)
 
-       const AddressLength = 4
-
        for n := 0; n < peerNumber; n++ {
                peers = append(peers, &Peer{})
        }
 
        for n := 0; n < addressNumber; n++ {
-               var addr [AddressLength]byte
-               rand.Read(addr[:])
-               cidr := uint8(rand.Uint32() % (AddressLength * 8))
+               var addr B
+               rand.Read(unsafe.Slice(&addr[0], len(addr)))
+               cidr := uint8(rand.Uint32() % uint32(len(addr)*8))
                index := rand.Int() % peerNumber
-               root.insert(addr[:], cidr, peers[index])
+               root.insert(addr, cidr, peers[index])
        }
 
        for n := 0; n < b.N; n++ {
-               var addr [AddressLength]byte
-               rand.Read(addr[:])
-               trie.lookup(addr[:])
+               var addr B
+               rand.Read(unsafe.Slice(&addr[0], len(addr)))
+               trie.lookup(addr)
        }
 }
 
 func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
-       benchmarkTrie(100, 1000, net.IPv4len, b)
+       benchmarkTrie[[4]byte](100, 1000, b)
 }
 
 func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
-       benchmarkTrie(10, 10, net.IPv4len, b)
+       benchmarkTrie[[4]byte](10, 10, b)
 }
 
 func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
-       benchmarkTrie(100, 1000, net.IPv6len, b)
+       benchmarkTrie[[16]byte](100, 1000, b)
 }
 
 func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
-       benchmarkTrie(10, 10, net.IPv6len, b)
+       benchmarkTrie[[16]byte](10, 10, b)
 }
 
 /* Test ported from kernel implementation: