]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: reduce size of trie struct
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 11:51:03 +0000 (13:51 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 11:51:03 +0000 (13:51 +0200)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_rand_test.go
device/allowedips_test.go
device/misc.go
device/uapi.go

index b6f096a0a2d1184d34b83d2e67721bec8236df79..1564d2d7f52599a6f718d855b61dd782dff18be9 100644 (file)
@@ -15,13 +15,13 @@ import (
 )
 
 type trieEntry struct {
-       child        [2]*trieEntry
-       peer         *Peer
-       bits         net.IP
-       cidr         uint
-       bit_at_byte  uint
-       bit_at_shift uint
-       perPeerElem  *list.Element
+       peer        *Peer
+       child       [2]*trieEntry
+       cidr        uint8
+       bitAtByte   uint8
+       bitAtShift  uint8
+       bits        net.IP
+       perPeerElem *list.Element
 }
 
 func isLittleEndian() bool {
@@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 {
        return bits.ReverseBytes64(i)
 }
 
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
        size := len(ip1)
        if size == net.IPv4len {
                a := (*uint32)(unsafe.Pointer(&ip1[0]))
                b := (*uint32)(unsafe.Pointer(&ip2[0]))
                x := *a ^ *b
-               return uint(bits.LeadingZeros32(swapU32(x)))
+               return uint8(bits.LeadingZeros32(swapU32(x)))
        } else if size == net.IPv6len {
                a := (*uint64)(unsafe.Pointer(&ip1[0]))
                b := (*uint64)(unsafe.Pointer(&ip2[0]))
                x := *a ^ *b
                if x != 0 {
-                       return uint(bits.LeadingZeros64(swapU64(x)))
+                       return uint8(bits.LeadingZeros64(swapU64(x)))
                }
                a = (*uint64)(unsafe.Pointer(&ip1[8]))
                b = (*uint64)(unsafe.Pointer(&ip2[8]))
                x = *a ^ *b
-               return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+               return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
        } else {
                panic("Wrong size bit string")
        }
@@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
 }
 
 func (node *trieEntry) choose(ip net.IP) byte {
-       return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+       return (ip[node.bitAtByte] >> node.bitAtShift) & 1
 }
 
 func (node *trieEntry) maskSelf() {
@@ -114,17 +114,17 @@ func (node *trieEntry) maskSelf() {
        }
 }
 
-func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
+func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
 
        // at leaf
 
        if node == nil {
                node := &trieEntry{
-                       bits:         ip,
-                       peer:         peer,
-                       cidr:         cidr,
-                       bit_at_byte:  cidr / 8,
-                       bit_at_shift: 7 - (cidr % 8),
+                       bits:       ip,
+                       peer:       peer,
+                       cidr:       cidr,
+                       bitAtByte:  cidr / 8,
+                       bitAtShift: 7 - (cidr % 8),
                }
                node.maskSelf()
                node.addToPeerEntries()
@@ -149,16 +149,18 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
        // split node
 
        newNode := &trieEntry{
-               bits:         ip,
-               peer:         peer,
-               cidr:         cidr,
-               bit_at_byte:  cidr / 8,
-               bit_at_shift: 7 - (cidr % 8),
+               bits:       ip,
+               peer:       peer,
+               cidr:       cidr,
+               bitAtByte:  cidr / 8,
+               bitAtShift: 7 - (cidr % 8),
        }
        newNode.maskSelf()
        newNode.addToPeerEntries()
 
-       cidr = min(cidr, common)
+       if common < cidr {
+               cidr = common
+       }
 
        // check for shorter prefix
 
@@ -171,11 +173,11 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
        // create new parent for node & newNode
 
        parent := &trieEntry{
-               bits:         append([]byte{}, ip...),
-               peer:         nil,
-               cidr:         cidr,
-               bit_at_byte:  cidr / 8,
-               bit_at_shift: 7 - (cidr % 8),
+               bits:       append([]byte{}, ip...),
+               peer:       nil,
+               cidr:       cidr,
+               bitAtByte:  cidr / 8,
+               bitAtShift: 7 - (cidr % 8),
        }
        parent.maskSelf()
 
@@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
 
 func (node *trieEntry) lookup(ip net.IP) *Peer {
        var found *Peer
-       size := uint(len(ip))
+       size := uint8(len(ip))
        for node != nil && commonBits(node.bits, ip) >= node.cidr {
                if node.peer != nil {
                        found = node.peer
                }
-               if node.bit_at_byte == size {
+               if node.bitAtByte == size {
                        break
                }
                bit := node.choose(ip)
@@ -208,7 +210,7 @@ type AllowedIPs struct {
        mutex sync.RWMutex
 }
 
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
 
@@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        table.IPv6 = table.IPv6.removeByPeer(peer)
 }
 
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
 
index bb3fb430e5277125333eec8f6838e265b3a18948..2da8795d95d3205dd2b7121a59c53f209b4472ba 100644 (file)
@@ -19,7 +19,7 @@ const (
 
 type SlowNode struct {
        peer *Peer
-       cidr uint
+       cidr uint8
        bits []byte
 }
 
@@ -37,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) {
        r[i], r[j] = r[j], r[i]
 }
 
-func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
+func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
        for _, t := range r {
                if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
                        t.peer = peer
@@ -80,7 +80,7 @@ func TestTrieRandomIPv4(t *testing.T) {
        for n := 0; n < NumberOfAddresses; n++ {
                var addr [AddressLength]byte
                rand.Read(addr[:])
-               cidr := uint(rand.Uint32() % (AddressLength * 8))
+               cidr := uint8(rand.Uint32() % (AddressLength * 8))
                index := rand.Int() % NumberOfPeers
                trie = trie.insert(addr[:], cidr, peers[index])
                slow = slow.Insert(addr[:], cidr, peers[index])
@@ -113,7 +113,7 @@ func TestTrieRandomIPv6(t *testing.T) {
        for n := 0; n < NumberOfAddresses; n++ {
                var addr [AddressLength]byte
                rand.Read(addr[:])
-               cidr := uint(rand.Uint32() % (AddressLength * 8))
+               cidr := uint8(rand.Uint32() % (AddressLength * 8))
                index := rand.Int() % NumberOfPeers
                trie = trie.insert(addr[:], cidr, peers[index])
                slow = slow.Insert(addr[:], cidr, peers[index])
index cdd65cf09e68ff1faec211e1673bbb81c0523efd..8dc84381d437d661032d23f6db7e443de3db1320 100644 (file)
@@ -11,13 +11,10 @@ import (
        "testing"
 )
 
-/* Todo: More comprehensive
- */
-
 type testPairCommonBits struct {
        s1    []byte
        s2    []byte
-       match uint
+       match uint8
 }
 
 func TestCommonBits(t *testing.T) {
@@ -57,7 +54,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
        for n := 0; n < addressNumber; n++ {
                var addr [AddressLength]byte
                rand.Read(addr[:])
-               cidr := uint(rand.Uint32() % (AddressLength * 8))
+               cidr := uint8(rand.Uint32() % (AddressLength * 8))
                index := rand.Int() % peerNumber
                trie = trie.insert(addr[:], cidr, peers[index])
        }
@@ -99,7 +96,7 @@ func TestTrieIPv4(t *testing.T) {
 
        var trie *trieEntry
 
-       insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
+       insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
                trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
        }
 
@@ -195,7 +192,7 @@ func TestTrieIPv6(t *testing.T) {
                return out[:]
        }
 
-       insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+       insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
                var addr []byte
                addr = append(addr, expand(a)...)
                addr = append(addr, expand(b)...)
index 2c2510f71cbaa8177d4c8c5b637e6bdc7874f8f8..4126704ca5c4ae0f5af4a8395432540bdc174bfe 100644 (file)
@@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) {
        }
        atomic.StoreInt32(&a.int32, flag)
 }
-
-func min(a, b uint) uint {
-       if a > b {
-               return b
-       }
-       return a
-}
index 659af0acbcbd14111537862d5baacf2217940e36..66ecd484fcce073138fc4bd72aee9b7719d32db0 100644 (file)
@@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
                        sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
                        sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
 
-                       device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
+                       device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
                                sendf("allowed_ip=%s/%d", ip.String(), cidr)
                                return true
                        })
@@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
                        return nil
                }
                ones, _ := network.Mask.Size()
-               device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
+               device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
 
        case "protocol_version":
                if value != "1" {