]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: remove recursion from insertion and connect parent pointers
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 12:50:28 +0000 (14:50 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 13:08:42 +0000 (15:08 +0200)
This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_rand_test.go
device/allowedips_test.go

index 1564d2d7f52599a6f718d855b61dd782dff18be9..d613121c462ccdf4e0c38dc618801d0f197e81f5 100644 (file)
@@ -14,9 +14,15 @@ import (
        "unsafe"
 )
 
+type parentIndirection struct {
+       parentBit     **trieEntry
+       parentBitType uint8
+}
+
 type trieEntry struct {
        peer        *Peer
        child       [2]*trieEntry
+       parent      parentIndirection
        cidr        uint8
        bitAtByte   uint8
        bitAtShift  uint8
@@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() {
        }
 }
 
-func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
-
-       // at leaf
+func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+       for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
+               parent = node
+               if parent.cidr == cidr {
+                       exact = true
+                       return
+               }
+               bit := node.choose(ip)
+               node = node.child[bit]
+       }
+       return
+}
 
-       if node == nil {
+func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+       if *trie.parentBit == nil {
                node := &trieEntry{
-                       bits:       ip,
                        peer:       peer,
+                       parent:     trie,
+                       bits:       ip,
                        cidr:       cidr,
                        bitAtByte:  cidr / 8,
                        bitAtShift: 7 - (cidr % 8),
                }
                node.maskSelf()
                node.addToPeerEntries()
-               return node
+               *trie.parentBit = node
+               return
        }
-
-       // traverse deeper
-
-       common := commonBits(node.bits, ip)
-       if node.cidr <= cidr && common >= node.cidr {
-               if node.cidr == cidr {
-                       node.removeFromPeerEntries()
-                       node.peer = peer
-                       node.addToPeerEntries()
-                       return node
-               }
-               bit := node.choose(ip)
-               node.child[bit] = node.child[bit].insert(ip, cidr, peer)
-               return node
+       node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
+       if exact {
+               node.removeFromPeerEntries()
+               node.peer = peer
+               node.addToPeerEntries()
+               return
        }
 
-       // split node
-
        newNode := &trieEntry{
-               bits:       ip,
                peer:       peer,
+               bits:       ip,
                cidr:       cidr,
                bitAtByte:  cidr / 8,
                bitAtShift: 7 - (cidr % 8),
@@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
        newNode.maskSelf()
        newNode.addToPeerEntries()
 
+       var down *trieEntry
+       if node == nil {
+               down = *trie.parentBit
+       } else {
+               bit := node.choose(ip)
+               down = node.child[bit]
+               if down == nil {
+                       newNode.parent = parentIndirection{&node.child[bit], bit}
+                       node.child[bit] = newNode
+                       return
+               }
+       }
+       common := commonBits(down.bits, ip)
        if common < cidr {
                cidr = common
        }
-
-       // check for shorter prefix
+       parent := node
 
        if newNode.cidr == cidr {
-               bit := newNode.choose(node.bits)
-               newNode.child[bit] = node
-               return newNode
+               bit := newNode.choose(down.bits)
+               down.parent = parentIndirection{&newNode.child[bit], bit}
+               newNode.child[bit] = down
+               if parent == nil {
+                       newNode.parent = trie
+                       *trie.parentBit = newNode
+               } else {
+                       bit := parent.choose(newNode.bits)
+                       newNode.parent = parentIndirection{&parent.child[bit], bit}
+                       parent.child[bit] = newNode
+               }
+               return
        }
 
-       // create new parent for node & newNode
-
-       parent := &trieEntry{
-               bits:       append([]byte{}, ip...),
-               peer:       nil,
+       node = &trieEntry{
+               bits:       append([]byte{}, newNode.bits...),
                cidr:       cidr,
                bitAtByte:  cidr / 8,
                bitAtShift: 7 - (cidr % 8),
        }
-       parent.maskSelf()
-
-       bit := parent.choose(ip)
-       parent.child[bit] = newNode
-       parent.child[bit^1] = node
-
-       return parent
+       node.maskSelf()
+
+       bit := node.choose(down.bits)
+       down.parent = parentIndirection{&node.child[bit], bit}
+       node.child[bit] = down
+       bit = node.choose(newNode.bits)
+       newNode.parent = parentIndirection{&node.child[bit], bit}
+       node.child[bit] = newNode
+       if parent == nil {
+               node.parent = trie
+               *trie.parentBit = node
+       } else {
+               bit := parent.choose(node.bits)
+               node.parent = parentIndirection{&parent.child[bit], bit}
+               parent.child[bit] = node
+       }
 }
 
 func (node *trieEntry) lookup(ip net.IP) *Peer {
@@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
 
        switch len(ip) {
        case net.IPv6len:
-               table.IPv6 = table.IPv6.insert(ip, cidr, peer)
+               parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
        case net.IPv4len:
-               table.IPv4 = table.IPv4.insert(ip, cidr, peer)
+               parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
        default:
                panic(errors.New("inserting unknown address type"))
        }
index 2da8795d95d3205dd2b7121a59c53f209b4472ba..48a5bcde6b0d4815bc8f7f7f12620f0c7742e6bf 100644 (file)
@@ -65,9 +65,9 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
 }
 
 func TestTrieRandomIPv4(t *testing.T) {
-       var trie *trieEntry
        var slow SlowRouter
        var peers []*Peer
+       var allowedIPs AllowedIPs
 
        rand.Seed(1)
 
@@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
                rand.Read(addr[:])
                cidr := uint8(rand.Uint32() % (AddressLength * 8))
                index := rand.Int() % NumberOfPeers
-               trie = trie.insert(addr[:], cidr, peers[index])
+               allowedIPs.Insert(addr[:], cidr, peers[index])
                slow = slow.Insert(addr[:], cidr, peers[index])
        }
 
@@ -90,7 +90,7 @@ func TestTrieRandomIPv4(t *testing.T) {
                var addr [AddressLength]byte
                rand.Read(addr[:])
                peer1 := slow.Lookup(addr[:])
-               peer2 := trie.lookup(addr[:])
+               peer2 := allowedIPs.LookupIPv4(addr[:])
                if peer1 != peer2 {
                        t.Error("Trie did not match naive implementation, for:", addr)
                }
@@ -98,9 +98,9 @@ func TestTrieRandomIPv4(t *testing.T) {
 }
 
 func TestTrieRandomIPv6(t *testing.T) {
-       var trie *trieEntry
        var slow SlowRouter
        var peers []*Peer
+       var allowedIPs AllowedIPs
 
        rand.Seed(1)
 
@@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
                rand.Read(addr[:])
                cidr := uint8(rand.Uint32() % (AddressLength * 8))
                index := rand.Int() % NumberOfPeers
-               trie = trie.insert(addr[:], cidr, peers[index])
+               allowedIPs.Insert(addr[:], cidr, peers[index])
                slow = slow.Insert(addr[:], cidr, peers[index])
        }
 
@@ -123,7 +123,7 @@ func TestTrieRandomIPv6(t *testing.T) {
                var addr [AddressLength]byte
                rand.Read(addr[:])
                peer1 := slow.Lookup(addr[:])
-               peer2 := trie.lookup(addr[:])
+               peer2 := allowedIPs.LookupIPv6(addr[:])
                if peer1 != peer2 {
                        t.Error("Trie did not match naive implementation, for:", addr)
                }
index 8dc84381d437d661032d23f6db7e443de3db1320..cbd32cc5fea7ad847cd5880702dc84ec16a7a30d 100644 (file)
@@ -42,6 +42,7 @@ func TestCommonBits(t *testing.T) {
 func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
        var trie *trieEntry
        var peers []*Peer
+       root := parentIndirection{&trie, 2}
 
        rand.Seed(1)
 
@@ -56,7 +57,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
                rand.Read(addr[:])
                cidr := uint8(rand.Uint32() % (AddressLength * 8))
                index := rand.Int() % peerNumber
-               trie = trie.insert(addr[:], cidr, peers[index])
+               root.insert(addr[:], cidr, peers[index])
        }
 
        for n := 0; n < b.N; n++ {
@@ -94,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
        g := &Peer{}
        h := &Peer{}
 
-       var trie *trieEntry
+       var allowedIPs AllowedIPs
 
        insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
-               trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+               allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
        }
 
        assertEQ := func(peer *Peer, a, b, c, d byte) {
-               p := trie.lookup([]byte{a, b, c, d})
+               p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
                if p != peer {
                        t.Error("Assert EQ failed")
                }
        }
 
        assertNEQ := func(peer *Peer, a, b, c, d byte) {
-               p := trie.lookup([]byte{a, b, c, d})
+               p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
                if p == peer {
                        t.Error("Assert NEQ failed")
                }
@@ -150,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
        assertEQ(a, 192, 0, 0, 0)
        assertEQ(a, 255, 0, 0, 0)
 
-       trie = trie.removeByPeer(a)
+       allowedIPs.RemoveByPeer(a)
 
        assertNEQ(a, 1, 0, 0, 0)
        assertNEQ(a, 64, 0, 0, 0)
@@ -158,12 +159,12 @@ func TestTrieIPv4(t *testing.T) {
        assertNEQ(a, 192, 0, 0, 0)
        assertNEQ(a, 255, 0, 0, 0)
 
-       trie = nil
+       allowedIPs = AllowedIPs{}
 
        insert(a, 192, 168, 0, 0, 16)
        insert(a, 192, 168, 0, 0, 24)
 
-       trie = trie.removeByPeer(a)
+       allowedIPs.RemoveByPeer(a)
 
        assertNEQ(a, 192, 168, 0, 1)
 }
@@ -181,7 +182,7 @@ func TestTrieIPv6(t *testing.T) {
        g := &Peer{}
        h := &Peer{}
 
-       var trie *trieEntry
+       var allowedIPs AllowedIPs
 
        expand := func(a uint32) []byte {
                var out [4]byte
@@ -198,7 +199,7 @@ func TestTrieIPv6(t *testing.T) {
                addr = append(addr, expand(b)...)
                addr = append(addr, expand(c)...)
                addr = append(addr, expand(d)...)
-               trie = trie.insert(addr, cidr, peer)
+               allowedIPs.Insert(addr, cidr, peer)
        }
 
        assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -207,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
                addr = append(addr, expand(b)...)
                addr = append(addr, expand(c)...)
                addr = append(addr, expand(d)...)
-               p := trie.lookup(addr)
+               p := allowedIPs.LookupIPv6(addr)
                if p != peer {
                        t.Error("Assert EQ failed")
                }