]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: add support for removing allowedips individually
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 20 May 2025 21:03:06 +0000 (23:03 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 20 May 2025 21:03:06 +0000 (23:03 +0200)
This pairs with the recent change in wireguard-tools.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_test.go
device/uapi.go

index b40c8170c795bdac97e5e48cb011bfd0ea61bb6d..d15373cfef0b2a8fa0585f62d1d4acddfe273c0f 100644 (file)
@@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
        }
 }
 
+func (node *trieEntry) remove() {
+       node.removeFromPeerEntries()
+       node.peer = nil
+       if node.child[0] != nil && node.child[1] != nil {
+               return
+       }
+       bit := 0
+       if node.child[0] == nil {
+               bit = 1
+       }
+       child := node.child[bit]
+       if child != nil {
+               child.parent = node.parent
+       }
+       *node.parent.parentBit = child
+       if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+               node.zeroizePointers()
+               return
+       }
+       parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+       if parent.peer != nil {
+               node.zeroizePointers()
+               return
+       }
+       child = parent.child[node.parent.parentBitType^1]
+       if child != nil {
+               child.parent = parent.parent
+       }
+       *parent.parent.parentBit = child
+       node.zeroizePointers()
+       parent.zeroizePointers()
+}
+
+func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
+       table.mutex.Lock()
+       defer table.mutex.Unlock()
+       var node *trieEntry
+       var exact bool
+
+       if prefix.Addr().Is6() {
+               ip := prefix.Addr().As16()
+               node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
+       } else if prefix.Addr().Is4() {
+               ip := prefix.Addr().As4()
+               node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
+       } else {
+               panic(errors.New("removing unknown address type"))
+       }
+       if !exact || node == nil || peer != node.peer {
+               return
+       }
+       node.remove()
+}
+
 func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
@@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        var next *list.Element
        for elem := peer.trieEntries.Front(); elem != nil; elem = next {
                next = elem.Next()
-               node := elem.Value.(*trieEntry)
-
-               node.removeFromPeerEntries()
-               node.peer = nil
-               if node.child[0] != nil && node.child[1] != nil {
-                       continue
-               }
-               bit := 0
-               if node.child[0] == nil {
-                       bit = 1
-               }
-               child := node.child[bit]
-               if child != nil {
-                       child.parent = node.parent
-               }
-               *node.parent.parentBit = child
-               if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
-                       node.zeroizePointers()
-                       continue
-               }
-               parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
-               if parent.peer != nil {
-                       node.zeroizePointers()
-                       continue
-               }
-               child = parent.child[node.parent.parentBitType^1]
-               if child != nil {
-                       child.parent = parent.parent
-               }
-               *parent.parent.parentBit = child
-               node.zeroizePointers()
-               parent.zeroizePointers()
+               elem.Value.(*trieEntry).remove()
        }
 }
 
index 7df7da5b8ba7a3543ab3435d75f4db955019f34f..a4b08a399ca85d11a6e324ec09e9447f38546f15 100644 (file)
@@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
                allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
        }
 
+       remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+               allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
+       }
+
        assertEQ := func(peer *Peer, a, b, c, d byte) {
                p := allowedIPs.Lookup([]byte{a, b, c, d})
                if p != peer {
@@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
        allowedIPs.RemoveByPeer(a)
 
        assertNEQ(a, 192, 168, 0, 1)
+
+       insert(a, 1, 0, 0, 0, 32)
+       insert(a, 192, 0, 0, 0, 24)
+       assertEQ(a, 1, 0, 0, 0)
+       assertEQ(a, 192, 0, 0, 1)
+       remove(a, 192, 0, 0, 0, 32)
+       assertEQ(a, 192, 0, 0, 1)
+       remove(nil, 192, 0, 0, 0, 24)
+       assertEQ(a, 192, 0, 0, 1)
+       remove(b, 192, 0, 0, 0, 24)
+       assertEQ(a, 192, 0, 0, 1)
+       remove(a, 192, 0, 0, 0, 24)
+       assertNEQ(a, 192, 0, 0, 1)
+       remove(a, 1, 0, 0, 0, 32)
+       assertNEQ(a, 1, 0, 0, 0)
 }
 
 /* Test ported from kernel implementation:
@@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
                allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
        }
 
+       remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
+               var addr []byte
+               addr = append(addr, expand(a)...)
+               addr = append(addr, expand(b)...)
+               addr = append(addr, expand(c)...)
+               addr = append(addr, expand(d)...)
+               allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
+       }
+
        assertEQ := func(peer *Peer, a, b, c, d uint32) {
                var addr []byte
                addr = append(addr, expand(a)...)
@@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
                }
        }
 
+       assertNEQ := func(peer *Peer, a, b, c, d uint32) {
+               var addr []byte
+               addr = append(addr, expand(a)...)
+               addr = append(addr, expand(b)...)
+               addr = append(addr, expand(c)...)
+               addr = append(addr, expand(d)...)
+               p := allowedIPs.Lookup(addr)
+               if p == peer {
+                       t.Error("Assert NEQ failed")
+               }
+       }
+
        insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
        insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
        insert(e, 0, 0, 0, 0, 0)
@@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
        assertEQ(h, 0x24046800, 0x40040800, 0, 0)
        assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
        assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
+
+       insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+       insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+       assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+       assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
+       remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
+       assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+       remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+       assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+       remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+       assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+       remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+       assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+       remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+       assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
+       remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+       assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
 }
index 521a7411c2ebf46158d22c6a3aefe52e5fb0ff8f..cc69488b4e2e3b6a957a62ce59c8cd65506eb2f7 100644 (file)
@@ -371,7 +371,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
                device.allowedips.RemoveByPeer(peer.Peer)
 
        case "allowed_ip":
-               device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
+               add := true
+               verb := "Adding"
+               if len(value) > 0 && value[0] == '-' {
+                       add = false
+                       verb = "Removing"
+                       value = value[1:]
+               }
+               device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
                prefix, err := netip.ParsePrefix(value)
                if err != nil {
                        return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
@@ -379,7 +386,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
                if peer.dummy {
                        return nil
                }
-               device.allowedips.Insert(prefix, peer.Peer)
+               if add {
+                       device.allowedips.Insert(prefix, peer.Peer)
+               } else {
+                       device.allowedips.Remove(prefix, peer.Peer)
+               }
 
        case "protocol_version":
                if value != "1" {