]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: remove nodes by peer in O(1) instead of O(n)
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 13:40:09 +0000 (15:40 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 14:29:43 +0000 (16:29 +0200)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_rand_test.go

index d613121c462ccdf4e0c38dc618801d0f197e81f5..7af9fc7b8035c4326b4890b32f3abca73cdaa078 100644 (file)
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
        }
 }
 
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
-       if node == nil {
-               return node
-       }
-
-       // walk recursively
-
-       node.child[0] = node.child[0].removeByPeer(p)
-       node.child[1] = node.child[1].removeByPeer(p)
-
-       if node.peer != p {
-               return node
-       }
-
-       // remove peer & merge
-
-       node.removeFromPeerEntries()
-       node.peer = nil
-       if node.child[0] == nil {
-               return node.child[1]
-       }
-       return node.child[0]
-}
-
 func (node *trieEntry) choose(ip net.IP) byte {
        return (ip[node.bitAtByte] >> node.bitAtShift) & 1
 }
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
 
-       table.IPv4 = table.IPv4.removeByPeer(peer)
-       table.IPv6 = table.IPv6.removeByPeer(peer)
+       var next *list.Element
+       for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+               next = elem.Next()
+               node := elem.Value.(*trieEntry)
+
+               node.removeFromPeerEntries()
+               node.peer = nil
+               if node.child[0] != nil && node.child[1] != nil {
+                       continue
+               }
+               bit := 0
+               if node.child[0] == nil {
+                       bit = 1
+               }
+               child := node.child[bit]
+               if child != nil {
+                       child.parent = node.parent
+               }
+               *node.parent.parentBit = child
+               if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+                       continue
+               }
+               parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+               if parent.peer != nil {
+                       continue
+               }
+               child = parent.child[node.parent.parentBitType^1]
+               if child != nil {
+                       child.parent = parent.parent
+               }
+               *parent.parent.parentBit = child
+       }
 }
 
 func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
index 48a5bcde6b0d4815bc8f7f7f12620f0c7742e6bf..c5f80fe937859ea72a60c57287b0de181a377ce3 100644 (file)
@@ -7,6 +7,7 @@ package device
 
 import (
        "math/rand"
+       "net"
        "sort"
        "testing"
 )
@@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
        return nil
 }
 
-func TestTrieRandomIPv4(t *testing.T) {
-       var slow SlowRouter
-       var peers []*Peer
-       var allowedIPs AllowedIPs
-
-       rand.Seed(1)
-
-       const AddressLength = 4
-
-       for n := 0; n < NumberOfPeers; n++ {
-               peers = append(peers, &Peer{})
-       }
-
-       for n := 0; n < NumberOfAddresses; n++ {
-               var addr [AddressLength]byte
-               rand.Read(addr[:])
-               cidr := uint8(rand.Uint32() % (AddressLength * 8))
-               index := rand.Int() % NumberOfPeers
-               allowedIPs.Insert(addr[:], cidr, peers[index])
-               slow = slow.Insert(addr[:], cidr, peers[index])
-       }
-
-       for n := 0; n < NumberOfTests; n++ {
-               var addr [AddressLength]byte
-               rand.Read(addr[:])
-               peer1 := slow.Lookup(addr[:])
-               peer2 := allowedIPs.LookupIPv4(addr[:])
-               if peer1 != peer2 {
-                       t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+       n := 0
+       for _, x := range r {
+               if x.peer != peer {
+                       r[n] = x
+                       n++
                }
        }
+       return r[:n]
 }
 
-func TestTrieRandomIPv6(t *testing.T) {
-       var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+       var slow4, slow6 SlowRouter
        var peers []*Peer
        var allowedIPs AllowedIPs
 
        rand.Seed(1)
 
-       const AddressLength = 16
-
        for n := 0; n < NumberOfPeers; n++ {
                peers = append(peers, &Peer{})
        }
 
        for n := 0; n < NumberOfAddresses; n++ {
-               var addr [AddressLength]byte
-               rand.Read(addr[:])
-               cidr := uint8(rand.Uint32() % (AddressLength * 8))
-               index := rand.Int() % NumberOfPeers
-               allowedIPs.Insert(addr[:], cidr, peers[index])
-               slow = slow.Insert(addr[:], cidr, peers[index])
+               var addr4 [4]byte
+               rand.Read(addr4[:])
+               cidr := uint8(rand.Intn(32) + 1)
+               index := rand.Intn(NumberOfPeers)
+               allowedIPs.Insert(addr4[:], cidr, peers[index])
+               slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+               var addr6 [16]byte
+               rand.Read(addr6[:])
+               cidr = uint8(rand.Intn(128) + 1)
+               index = rand.Intn(NumberOfPeers)
+               allowedIPs.Insert(addr6[:], cidr, peers[index])
+               slow6 = slow6.Insert(addr6[:], cidr, peers[index])
        }
 
-       for n := 0; n < NumberOfTests; n++ {
-               var addr [AddressLength]byte
-               rand.Read(addr[:])
-               peer1 := slow.Lookup(addr[:])
-               peer2 := allowedIPs.LookupIPv6(addr[:])
-               if peer1 != peer2 {
-                       t.Error("Trie did not match naive implementation, for:", addr)
+       for p := 0; ; p++ {
+               for n := 0; n < NumberOfTests; n++ {
+                       var addr4 [4]byte
+                       rand.Read(addr4[:])
+                       peer1 := slow4.Lookup(addr4[:])
+                       peer2 := allowedIPs.LookupIPv4(addr4[:])
+                       if peer1 != peer2 {
+                               t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+                       }
+
+                       var addr6 [16]byte
+                       rand.Read(addr6[:])
+                       peer1 = slow6.Lookup(addr6[:])
+                       peer2 = allowedIPs.LookupIPv6(addr6[:])
+                       if peer1 != peer2 {
+                               t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+                       }
+               }
+               if p >= len(peers) {
+                       break
                }
+               allowedIPs.RemoveByPeer(peers[p])
+               slow4 = slow4.RemoveByPeer(peers[p])
+               slow6 = slow6.RemoveByPeer(peers[p])
+       }
+
+       if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+               t.Error("Failed to remove all nodes from trie by peer")
        }
 }