]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
netfilter: nft_set_rbtree: use binary search array in get command
authorPablo Neira Ayuso <pablo@netfilter.org>
Wed, 21 Jan 2026 00:08:46 +0000 (01:08 +0100)
committerFlorian Westphal <fw@strlen.de>
Thu, 22 Jan 2026 16:18:13 +0000 (17:18 +0100)
Rework .get interface to use the binary search array, this needs a specific
lookup function to match on end intervals (<=). Packet path lookup is slight
different because match is on lesser value, not equal (ie. <).

After this patch, seqcount can be removed in a follow up patch.

Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
Signed-off-by: Florian Westphal <fw@strlen.de>
net/netfilter/nft_set_rbtree.c

index 821808e8da060647cb1aee75b2be46047902bbaf..de2cce96023e95a5ea9ac8f2a5465986f29d1487 100644 (file)
@@ -62,96 +62,6 @@ static int nft_rbtree_cmp(const struct nft_set *set,
                      set->klen);
 }
 
-static bool __nft_rbtree_get(const struct net *net, const struct nft_set *set,
-                            const u32 *key, struct nft_rbtree_elem **elem,
-                            unsigned int seq, unsigned int flags, u8 genmask)
-{
-       struct nft_rbtree_elem *rbe, *interval = NULL;
-       struct nft_rbtree *priv = nft_set_priv(set);
-       const struct rb_node *parent;
-       const void *this;
-       int d;
-
-       parent = rcu_dereference_raw(priv->root.rb_node);
-       while (parent != NULL) {
-               if (read_seqcount_retry(&priv->count, seq))
-                       return false;
-
-               rbe = rb_entry(parent, struct nft_rbtree_elem, node);
-
-               this = nft_set_ext_key(&rbe->ext);
-               d = memcmp(this, key, set->klen);
-               if (d < 0) {
-                       parent = rcu_dereference_raw(parent->rb_left);
-                       if (!(flags & NFT_SET_ELEM_INTERVAL_END))
-                               interval = rbe;
-               } else if (d > 0) {
-                       parent = rcu_dereference_raw(parent->rb_right);
-                       if (flags & NFT_SET_ELEM_INTERVAL_END)
-                               interval = rbe;
-               } else {
-                       if (!nft_set_elem_active(&rbe->ext, genmask)) {
-                               parent = rcu_dereference_raw(parent->rb_left);
-                               continue;
-                       }
-
-                       if (nft_set_elem_expired(&rbe->ext))
-                               return false;
-
-                       if (!nft_set_ext_exists(&rbe->ext, NFT_SET_EXT_FLAGS) ||
-                           (*nft_set_ext_flags(&rbe->ext) & NFT_SET_ELEM_INTERVAL_END) ==
-                           (flags & NFT_SET_ELEM_INTERVAL_END)) {
-                               *elem = rbe;
-                               return true;
-                       }
-
-                       if (nft_rbtree_interval_end(rbe))
-                               interval = NULL;
-
-                       parent = rcu_dereference_raw(parent->rb_left);
-               }
-       }
-
-       if (set->flags & NFT_SET_INTERVAL && interval != NULL &&
-           nft_set_elem_active(&interval->ext, genmask) &&
-           !nft_set_elem_expired(&interval->ext) &&
-           ((!nft_rbtree_interval_end(interval) &&
-             !(flags & NFT_SET_ELEM_INTERVAL_END)) ||
-            (nft_rbtree_interval_end(interval) &&
-             (flags & NFT_SET_ELEM_INTERVAL_END)))) {
-               *elem = interval;
-               return true;
-       }
-
-       return false;
-}
-
-static struct nft_elem_priv *
-nft_rbtree_get(const struct net *net, const struct nft_set *set,
-              const struct nft_set_elem *elem, unsigned int flags)
-{
-       struct nft_rbtree *priv = nft_set_priv(set);
-       unsigned int seq = read_seqcount_begin(&priv->count);
-       struct nft_rbtree_elem *rbe = ERR_PTR(-ENOENT);
-       const u32 *key = (const u32 *)&elem->key.val;
-       u8 genmask = nft_genmask_cur(net);
-       bool ret;
-
-       ret = __nft_rbtree_get(net, set, key, &rbe, seq, flags, genmask);
-       if (ret || !read_seqcount_retry(&priv->count, seq))
-               return &rbe->priv;
-
-       read_lock_bh(&priv->lock);
-       seq = read_seqcount_begin(&priv->count);
-       ret = __nft_rbtree_get(net, set, key, &rbe, seq, flags, genmask);
-       read_unlock_bh(&priv->lock);
-
-       if (!ret)
-               return ERR_PTR(-ENOENT);
-
-       return &rbe->priv;
-}
-
 struct nft_array_lookup_ctx {
        const u32       *key;
        u32             klen;
@@ -206,6 +116,70 @@ nft_rbtree_lookup(const struct net *net, const struct nft_set *set,
        return interval->from;
 }
 
+struct nft_array_get_ctx {
+       const u32       *key;
+       unsigned int    flags;
+       u32             klen;
+};
+
+static int nft_array_get_cmp(const void *pkey, const void *entry)
+{
+       const struct nft_array_interval *interval = entry;
+       const struct nft_array_get_ctx *ctx = pkey;
+       int a, b;
+
+       if (!interval->from)
+               return 1;
+
+       a = memcmp(ctx->key, nft_set_ext_key(interval->from), ctx->klen);
+       if (!interval->to)
+               b = -1;
+       else
+               b = memcmp(ctx->key, nft_set_ext_key(interval->to), ctx->klen);
+
+       if (a >= 0) {
+               if (ctx->flags & NFT_SET_ELEM_INTERVAL_END && b <= 0)
+                       return 0;
+               else if (b < 0)
+                       return 0;
+       }
+
+       if (a < 0)
+               return -1;
+
+       return 1;
+}
+
+static struct nft_elem_priv *
+nft_rbtree_get(const struct net *net, const struct nft_set *set,
+              const struct nft_set_elem *elem, unsigned int flags)
+{
+       struct nft_rbtree *priv = nft_set_priv(set);
+       struct nft_array *array = rcu_dereference(priv->array);
+       const struct nft_array_interval *interval;
+       struct nft_array_get_ctx ctx = {
+               .key    = (const u32 *)&elem->key.val,
+               .flags  = flags,
+               .klen   = set->klen,
+       };
+       struct nft_rbtree_elem *rbe;
+
+       if (!array)
+               return ERR_PTR(-ENOENT);
+
+       interval = bsearch(&ctx, array->intervals, array->num_intervals,
+                          sizeof(struct nft_array_interval), nft_array_get_cmp);
+       if (!interval || nft_set_elem_expired(interval->from))
+               return ERR_PTR(-ENOENT);
+
+       if (flags & NFT_SET_ELEM_INTERVAL_END)
+               rbe = container_of(interval->to, struct nft_rbtree_elem, ext);
+       else
+               rbe = container_of(interval->from, struct nft_rbtree_elem, ext);
+
+       return &rbe->priv;
+}
+
 static void nft_rbtree_gc_elem_remove(struct net *net, struct nft_set *set,
                                      struct nft_rbtree *priv,
                                      struct nft_rbtree_elem *rbe)