]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: use linked list for per-peer allowed-ip traversal
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 26 Jan 2021 22:44:37 +0000 (23:44 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Wed, 27 Jan 2021 00:48:58 +0000 (01:48 +0100)
This makes the IpcGet method much faster.

We also refactor the traversal API to use a callback so that we don't
need to allocate at all. Avoiding allocations we do self-masking on
insertion, which in turn means that split intermediate nodes require a
copy of the bits.

benchmark               old ns/op     new ns/op     delta
BenchmarkUAPIGet-16     3243          2659          -18.01%

benchmark               old allocs     new allocs     delta
BenchmarkUAPIGet-16     35             30             -14.29%

benchmark               old bytes     new bytes     delta
BenchmarkUAPIGet-16     1218          737           -39.49%

This benchmark is good, though it's only for a pair of peers, each with
only one allowedips. As this grows, the delta expands considerably.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/device.go
device/peer.go
device/uapi.go

index 143bda35c53d2c3608e4b2090552dbeb12f7930a..b5e40e96c1d6a496e5469680396bec15058fb767 100644 (file)
@@ -14,15 +14,14 @@ import (
 )
 
 type trieEntry struct {
-       cidr  uint
-       child [2]*trieEntry
-       bits  net.IP
-       peer  *Peer
-
-       // index of "branching" bit
-
-       bit_at_byte  uint
-       bit_at_shift uint
+       child             [2]*trieEntry
+       peer              *Peer
+       bits              net.IP
+       cidr              uint
+       bit_at_byte       uint
+       bit_at_shift      uint
+       nextEntryForPeer  *trieEntry
+       pprevEntryForPeer **trieEntry
 }
 
 func isLittleEndian() bool {
@@ -69,6 +68,31 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint {
        }
 }
 
+func (node *trieEntry) addToPeerEntries() {
+       p := node.peer
+       first := p.firstTrieEntry
+       node.nextEntryForPeer = first
+       if first != nil {
+               first.pprevEntryForPeer = &node.nextEntryForPeer
+       }
+       p.firstTrieEntry = node
+       node.pprevEntryForPeer = &p.firstTrieEntry
+}
+
+func (node *trieEntry) removeFromPeerEntries() {
+       if node.pprevEntryForPeer == nil {
+               return
+       }
+       next := node.nextEntryForPeer
+       pprev := node.pprevEntryForPeer
+       *pprev = next
+       if next != nil {
+               next.pprevEntryForPeer = pprev
+       }
+       node.nextEntryForPeer = nil
+       node.pprevEntryForPeer = nil
+}
+
 func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
        if node == nil {
                return node
@@ -85,6 +109,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
 
        // remove peer & merge
 
+       node.removeFromPeerEntries()
        node.peer = nil
        if node.child[0] == nil {
                return node.child[1]
@@ -96,18 +121,28 @@ func (node *trieEntry) choose(ip net.IP) byte {
        return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
 }
 
+func (node *trieEntry) maskSelf() {
+       mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+       for i := 0; i < len(mask); i++ {
+               node.bits[i] &= mask[i]
+       }
+}
+
 func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
 
        // at leaf
 
        if node == nil {
-               return &trieEntry{
+               node := &trieEntry{
                        bits:         ip,
                        peer:         peer,
                        cidr:         cidr,
                        bit_at_byte:  cidr / 8,
                        bit_at_shift: 7 - (cidr % 8),
                }
+               node.maskSelf()
+               node.addToPeerEntries()
+               return node
        }
 
        // traverse deeper
@@ -115,7 +150,9 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
        common := commonBits(node.bits, ip)
        if node.cidr <= cidr && common >= node.cidr {
                if node.cidr == cidr {
+                       node.removeFromPeerEntries()
                        node.peer = peer
+                       node.addToPeerEntries()
                        return node
                }
                bit := node.choose(ip)
@@ -132,6 +169,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
                bit_at_byte:  cidr / 8,
                bit_at_shift: 7 - (cidr % 8),
        }
+       newNode.maskSelf()
+       newNode.addToPeerEntries()
 
        cidr = min(cidr, common)
 
@@ -146,12 +185,13 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
        // create new parent for node & newNode
 
        parent := &trieEntry{
-               bits:         ip,
+               bits:         append([]byte{}, ip...),
                peer:         nil,
                cidr:         cidr,
                bit_at_byte:  cidr / 8,
                bit_at_shift: 7 - (cidr % 8),
        }
+       parent.maskSelf()
 
        bit := parent.choose(ip)
        parent.child[bit] = newNode
@@ -176,44 +216,21 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
        return found
 }
 
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
-       if node == nil {
-               return results
-       }
-       if node.peer == p {
-               mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
-               results = append(results, net.IPNet{
-                       Mask: mask,
-                       IP:   node.bits.Mask(mask),
-               })
-       }
-       results = node.child[0].entriesForPeer(p, results)
-       results = node.child[1].entriesForPeer(p, results)
-       return results
-}
-
 type AllowedIPs struct {
        IPv4  *trieEntry
        IPv6  *trieEntry
        mutex sync.RWMutex
 }
 
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
 
-       allowed := make([]net.IPNet, 0, 10)
-       allowed = table.IPv4.entriesForPeer(peer, allowed)
-       allowed = table.IPv6.entriesForPeer(peer, allowed)
-       return allowed
-}
-
-func (table *AllowedIPs) Reset() {
-       table.mutex.Lock()
-       defer table.mutex.Unlock()
-
-       table.IPv4 = nil
-       table.IPv6 = nil
+       for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer {
+               if !cb(node.bits, node.cidr) {
+                       return
+               }
+       }
 }
 
 func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
index ebcbd9efd0fef9cf15e5bde9177686969fdd38d4..47c49446b81d141a7bcf17842c7bbabc01307761 100644 (file)
@@ -314,7 +314,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
        device.rate.underLoadUntil.Store(time.Time{})
 
        device.indexTable.Init()
-       device.allowedips.Reset()
 
        device.PopulatePools()
 
index 5324ae471031ca36efc201b1850abe08b6e09724..a103b5d0e739ed4446facd3eeecdd19fbaa0db88 100644 (file)
@@ -28,6 +28,7 @@ type Peer struct {
        device                      *Device
        endpoint                    conn.Endpoint
        persistentKeepaliveInterval uint32 // accessed atomically
+       firstTrieEntry              *trieEntry
 
        // These fields are accessed with atomic operations, which must be
        // 64-bit aligned even on 32-bit platforms. Go guarantees that an
index 148a7a2706ba08753b488e98ab26adf512e4ecaf..cbfe25ef631313c50092ce79e0614a51572a46d0 100644 (file)
@@ -108,9 +108,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
                        sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
                        sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
 
-                       for _, ip := range device.allowedips.EntriesForPeer(peer) {
-                               sendf("allowed_ip=%s", ip.String())
-                       }
+                       device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
+                               sendf("allowed_ip=%s/%d", ip.String(), cidr)
+                               return true
+                       })
                }
        }()