]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: zero out allowedip node pointers when removing
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 4 Jun 2021 14:33:28 +0000 (16:33 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 4 Jun 2021 14:33:28 +0000 (16:33 +0200)
This should make it a bit easier for the garbage collector.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_test.go

index 95615ab3592149d89eb4fcb92792d2fc252b59e3..c08399bbf69efb0f841db0a121338fb14753064d 100644 (file)
@@ -96,6 +96,14 @@ func (node *trieEntry) maskSelf() {
        }
 }
 
+func (node *trieEntry) zeroizePointers() {
+       // Make the garbage collector's life slightly easier
+       node.peer = nil
+       node.child[0] = nil
+       node.child[1] = nil
+       node.parent.parentBit = nil
+}
+
 func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
        for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
                parent = node
@@ -257,10 +265,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
                }
                *node.parent.parentBit = child
                if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+                       node.zeroizePointers()
                        continue
                }
                parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
                if parent.peer != nil {
+                       node.zeroizePointers()
                        continue
                }
                child = parent.child[node.parent.parentBitType^1]
@@ -268,6 +278,8 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
                        child.parent = parent.parent
                }
                *parent.parent.parentBit = child
+               node.zeroizePointers()
+               parent.zeroizePointers()
        }
 }
 
index 7701cde450d864c6161d6a0c031ea2ef4a3b3284..2059a8836d0f4da2c6e411d01e8c899e31acf2ba 100644 (file)
@@ -159,7 +159,16 @@ func TestTrieIPv4(t *testing.T) {
        assertNEQ(a, 192, 0, 0, 0)
        assertNEQ(a, 255, 0, 0, 0)
 
-       allowedIPs = AllowedIPs{}
+       allowedIPs.RemoveByPeer(a)
+       allowedIPs.RemoveByPeer(b)
+       allowedIPs.RemoveByPeer(c)
+       allowedIPs.RemoveByPeer(d)
+       allowedIPs.RemoveByPeer(e)
+       allowedIPs.RemoveByPeer(g)
+       allowedIPs.RemoveByPeer(h)
+       if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+               t.Error("Expected removing all the peers to empty trie, but it did not")
+       }
 
        insert(a, 192, 168, 0, 0, 16)
        insert(a, 192, 168, 0, 0, 24)