--- /dev/null
+From 628bd3e49cba1c066228e23d71a852c23e26da73 Mon Sep 17 00:00:00 2001
+From: Pablo Neira Ayuso <pablo@netfilter.org>
+Date: Fri, 16 Jun 2023 14:51:49 +0200
+Subject: netfilter: nf_tables: drop map element references from preparation phase
+
+From: Pablo Neira Ayuso <pablo@netfilter.org>
+
+commit 628bd3e49cba1c066228e23d71a852c23e26da73 upstream.
+
+set .destroy callback releases the references to other objects in maps.
+This is very late and it results in spurious EBUSY errors. Drop refcount
+from the preparation phase instead, update set backend not to drop
+reference counter from set .destroy path.
+
+Exceptions: NFT_TRANS_PREPARE_ERROR does not require to drop the
+reference counter because the transaction abort path releases the map
+references for each element since the set is unbound. The abort path
+also deals with releasing reference counter for new elements added to
+unbound sets.
+
+Fixes: 591054469b3e ("netfilter: nf_tables: revisit chain/object refcounting from elements")
+Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
+Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
+---
+ include/net/netfilter/nf_tables.h | 5 +
+ net/netfilter/nf_tables_api.c | 145 +++++++++++++++++++++++++++++++++-----
+ net/netfilter/nft_set_bitmap.c | 5 -
+ net/netfilter/nft_set_hash.c | 23 ++++--
+ net/netfilter/nft_set_pipapo.c | 14 ++-
+ net/netfilter/nft_set_rbtree.c | 5 -
+ 6 files changed, 166 insertions(+), 31 deletions(-)
+
+--- a/include/net/netfilter/nf_tables.h
++++ b/include/net/netfilter/nf_tables.h
+@@ -437,7 +437,8 @@ struct nft_set_ops {
+ int (*init)(const struct nft_set *set,
+ const struct nft_set_desc *desc,
+ const struct nlattr * const nla[]);
+- void (*destroy)(const struct nft_set *set);
++ void (*destroy)(const struct nft_ctx *ctx,
++ const struct nft_set *set);
+ void (*gc_init)(const struct nft_set *set);
+
+ unsigned int elemsize;
+@@ -772,6 +773,8 @@ int nft_set_elem_expr_clone(const struct
+ struct nft_expr *expr_array[]);
+ void nft_set_elem_destroy(const struct nft_set *set, void *elem,
+ bool destroy_expr);
++void nf_tables_set_elem_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set, void *elem);
+
+ /**
+ * struct nft_set_gc_batch_head - nf_tables set garbage collection batch
+--- a/net/netfilter/nf_tables_api.c
++++ b/net/netfilter/nf_tables_api.c
+@@ -581,6 +581,58 @@ static int nft_trans_set_add(const struc
+ return __nft_trans_set_add(ctx, msg_type, set, NULL);
+ }
+
++static void nft_setelem_data_deactivate(const struct net *net,
++ const struct nft_set *set,
++ struct nft_set_elem *elem);
++
++static int nft_mapelem_deactivate(const struct nft_ctx *ctx,
++ struct nft_set *set,
++ const struct nft_set_iter *iter,
++ struct nft_set_elem *elem)
++{
++ nft_setelem_data_deactivate(ctx->net, set, elem);
++
++ return 0;
++}
++
++struct nft_set_elem_catchall {
++ struct list_head list;
++ struct rcu_head rcu;
++ void *elem;
++};
++
++static void nft_map_catchall_deactivate(const struct nft_ctx *ctx,
++ struct nft_set *set)
++{
++ u8 genmask = nft_genmask_next(ctx->net);
++ struct nft_set_elem_catchall *catchall;
++ struct nft_set_elem elem;
++ struct nft_set_ext *ext;
++
++ list_for_each_entry(catchall, &set->catchall_list, list) {
++ ext = nft_set_elem_ext(set, catchall->elem);
++ if (!nft_set_elem_active(ext, genmask))
++ continue;
++
++ elem.priv = catchall->elem;
++ nft_setelem_data_deactivate(ctx->net, set, &elem);
++ break;
++ }
++}
++
++static void nft_map_deactivate(const struct nft_ctx *ctx, struct nft_set *set)
++{
++ struct nft_set_iter iter = {
++ .genmask = nft_genmask_next(ctx->net),
++ .fn = nft_mapelem_deactivate,
++ };
++
++ set->ops->walk(ctx, set, &iter);
++ WARN_ON_ONCE(iter.err);
++
++ nft_map_catchall_deactivate(ctx, set);
++}
++
+ static int nft_delset(const struct nft_ctx *ctx, struct nft_set *set)
+ {
+ int err;
+@@ -589,6 +641,9 @@ static int nft_delset(const struct nft_c
+ if (err < 0)
+ return err;
+
++ if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT))
++ nft_map_deactivate(ctx, set);
++
+ nft_deactivate_next(ctx->net, set);
+ ctx->table->use--;
+
+@@ -3407,12 +3462,6 @@ int nft_setelem_validate(const struct nf
+ return 0;
+ }
+
+-struct nft_set_elem_catchall {
+- struct list_head list;
+- struct rcu_head rcu;
+- void *elem;
+-};
+-
+ int nft_set_catchall_validate(const struct nft_ctx *ctx, struct nft_set *set)
+ {
+ u8 genmask = nft_genmask_next(ctx->net);
+@@ -4734,7 +4783,7 @@ err_set_expr_alloc:
+ for (i = 0; i < set->num_exprs; i++)
+ nft_expr_destroy(&ctx, set->exprs[i]);
+ err_set_destroy:
+- ops->destroy(set);
++ ops->destroy(&ctx, set);
+ err_set_init:
+ kfree(set->name);
+ err_set_name:
+@@ -4749,7 +4798,7 @@ static void nft_set_catchall_destroy(con
+
+ list_for_each_entry_safe(catchall, next, &set->catchall_list, list) {
+ list_del_rcu(&catchall->list);
+- nft_set_elem_destroy(set, catchall->elem, true);
++ nf_tables_set_elem_destroy(ctx, set, catchall->elem);
+ kfree_rcu(catchall, rcu);
+ }
+ }
+@@ -4764,7 +4813,7 @@ static void nft_set_destroy(const struct
+ for (i = 0; i < set->num_exprs; i++)
+ nft_expr_destroy(ctx, set->exprs[i]);
+
+- set->ops->destroy(set);
++ set->ops->destroy(ctx, set);
+ nft_set_catchall_destroy(ctx, set);
+ kfree(set->name);
+ kvfree(set);
+@@ -4925,10 +4974,60 @@ static void nf_tables_unbind_set(const s
+ }
+ }
+
++static void nft_setelem_data_activate(const struct net *net,
++ const struct nft_set *set,
++ struct nft_set_elem *elem);
++
++static int nft_mapelem_activate(const struct nft_ctx *ctx,
++ struct nft_set *set,
++ const struct nft_set_iter *iter,
++ struct nft_set_elem *elem)
++{
++ nft_setelem_data_activate(ctx->net, set, elem);
++
++ return 0;
++}
++
++static void nft_map_catchall_activate(const struct nft_ctx *ctx,
++ struct nft_set *set)
++{
++ u8 genmask = nft_genmask_next(ctx->net);
++ struct nft_set_elem_catchall *catchall;
++ struct nft_set_elem elem;
++ struct nft_set_ext *ext;
++
++ list_for_each_entry(catchall, &set->catchall_list, list) {
++ ext = nft_set_elem_ext(set, catchall->elem);
++ if (!nft_set_elem_active(ext, genmask))
++ continue;
++
++ elem.priv = catchall->elem;
++ nft_setelem_data_activate(ctx->net, set, &elem);
++ break;
++ }
++}
++
++static void nft_map_activate(const struct nft_ctx *ctx, struct nft_set *set)
++{
++ struct nft_set_iter iter = {
++ .genmask = nft_genmask_next(ctx->net),
++ .fn = nft_mapelem_activate,
++ };
++
++ set->ops->walk(ctx, set, &iter);
++ WARN_ON_ONCE(iter.err);
++
++ nft_map_catchall_activate(ctx, set);
++}
++
+ void nf_tables_activate_set(const struct nft_ctx *ctx, struct nft_set *set)
+ {
+- if (nft_set_is_anonymous(set))
++ if (nft_set_is_anonymous(set)) {
++ if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT))
++ nft_map_activate(ctx, set);
++
+ nft_clear(ctx->net, set);
++ }
+
+ set->use++;
+ }
+@@ -4947,13 +5046,20 @@ void nf_tables_deactivate_set(const stru
+ set->use--;
+ break;
+ case NFT_TRANS_PREPARE:
+- if (nft_set_is_anonymous(set))
+- nft_deactivate_next(ctx->net, set);
++ if (nft_set_is_anonymous(set)) {
++ if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT))
++ nft_map_deactivate(ctx, set);
+
++ nft_deactivate_next(ctx->net, set);
++ }
+ set->use--;
+ return;
+ case NFT_TRANS_ABORT:
+ case NFT_TRANS_RELEASE:
++ if (nft_set_is_anonymous(set) &&
++ set->flags & (NFT_SET_MAP | NFT_SET_OBJECT))
++ nft_map_deactivate(ctx, set);
++
+ set->use--;
+ fallthrough;
+ default:
+@@ -5669,6 +5775,7 @@ static void nft_set_elem_expr_destroy(co
+ __nft_set_elem_expr_destroy(ctx, expr);
+ }
+
++/* Drop references and destroy. Called from gc, dynset and abort path. */
+ void nft_set_elem_destroy(const struct nft_set *set, void *elem,
+ bool destroy_expr)
+ {
+@@ -5690,11 +5797,11 @@ void nft_set_elem_destroy(const struct n
+ }
+ EXPORT_SYMBOL_GPL(nft_set_elem_destroy);
+
+-/* Only called from commit path, nft_setelem_data_deactivate() already deals
+- * with the refcounting from the preparation phase.
++/* Destroy element. References have been already dropped in the preparation
++ * path via nft_setelem_data_deactivate().
+ */
+-static void nf_tables_set_elem_destroy(const struct nft_ctx *ctx,
+- const struct nft_set *set, void *elem)
++void nf_tables_set_elem_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set, void *elem)
+ {
+ struct nft_set_ext *ext = nft_set_elem_ext(set, elem);
+
+@@ -9323,6 +9430,9 @@ static int __nf_tables_abort(struct net
+ case NFT_MSG_DELSET:
+ trans->ctx.table->use++;
+ nft_clear(trans->ctx.net, nft_trans_set(trans));
++ if (nft_trans_set(trans)->flags & (NFT_SET_MAP | NFT_SET_OBJECT))
++ nft_map_activate(&trans->ctx, nft_trans_set(trans));
++
+ nft_trans_destroy(trans);
+ break;
+ case NFT_MSG_NEWSETELEM:
+@@ -10089,6 +10199,9 @@ static void __nft_release_table(struct n
+ list_for_each_entry_safe(set, ns, &table->sets, list) {
+ list_del(&set->list);
+ table->use--;
++ if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT))
++ nft_map_deactivate(&ctx, set);
++
+ nft_set_destroy(&ctx, set);
+ }
+ list_for_each_entry_safe(obj, ne, &table->objects, list) {
+--- a/net/netfilter/nft_set_bitmap.c
++++ b/net/netfilter/nft_set_bitmap.c
+@@ -271,13 +271,14 @@ static int nft_bitmap_init(const struct
+ return 0;
+ }
+
+-static void nft_bitmap_destroy(const struct nft_set *set)
++static void nft_bitmap_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set)
+ {
+ struct nft_bitmap *priv = nft_set_priv(set);
+ struct nft_bitmap_elem *be, *n;
+
+ list_for_each_entry_safe(be, n, &priv->list, head)
+- nft_set_elem_destroy(set, be, true);
++ nf_tables_set_elem_destroy(ctx, set, be);
+ }
+
+ static bool nft_bitmap_estimate(const struct nft_set_desc *desc, u32 features,
+--- a/net/netfilter/nft_set_hash.c
++++ b/net/netfilter/nft_set_hash.c
+@@ -400,19 +400,31 @@ static int nft_rhash_init(const struct n
+ return 0;
+ }
+
++struct nft_rhash_ctx {
++ const struct nft_ctx ctx;
++ const struct nft_set *set;
++};
++
+ static void nft_rhash_elem_destroy(void *ptr, void *arg)
+ {
+- nft_set_elem_destroy(arg, ptr, true);
++ struct nft_rhash_ctx *rhash_ctx = arg;
++
++ nf_tables_set_elem_destroy(&rhash_ctx->ctx, rhash_ctx->set, ptr);
+ }
+
+-static void nft_rhash_destroy(const struct nft_set *set)
++static void nft_rhash_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set)
+ {
+ struct nft_rhash *priv = nft_set_priv(set);
++ struct nft_rhash_ctx rhash_ctx = {
++ .ctx = *ctx,
++ .set = set,
++ };
+
+ cancel_delayed_work_sync(&priv->gc_work);
+ rcu_barrier();
+ rhashtable_free_and_destroy(&priv->ht, nft_rhash_elem_destroy,
+- (void *)set);
++ (void *)&rhash_ctx);
+ }
+
+ /* Number of buckets is stored in u32, so cap our result to 1U<<31 */
+@@ -643,7 +655,8 @@ static int nft_hash_init(const struct nf
+ return 0;
+ }
+
+-static void nft_hash_destroy(const struct nft_set *set)
++static void nft_hash_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set)
+ {
+ struct nft_hash *priv = nft_set_priv(set);
+ struct nft_hash_elem *he;
+@@ -653,7 +666,7 @@ static void nft_hash_destroy(const struc
+ for (i = 0; i < priv->buckets; i++) {
+ hlist_for_each_entry_safe(he, next, &priv->table[i], node) {
+ hlist_del_rcu(&he->node);
+- nft_set_elem_destroy(set, he, true);
++ nf_tables_set_elem_destroy(ctx, set, he);
+ }
+ }
+ }
+--- a/net/netfilter/nft_set_pipapo.c
++++ b/net/netfilter/nft_set_pipapo.c
+@@ -2152,10 +2152,12 @@ out_scratch:
+
+ /**
+ * nft_set_pipapo_match_destroy() - Destroy elements from key mapping array
++ * @ctx: context
+ * @set: nftables API set representation
+ * @m: matching data pointing to key mapping array
+ */
+-static void nft_set_pipapo_match_destroy(const struct nft_set *set,
++static void nft_set_pipapo_match_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set,
+ struct nft_pipapo_match *m)
+ {
+ struct nft_pipapo_field *f;
+@@ -2172,15 +2174,17 @@ static void nft_set_pipapo_match_destroy
+
+ e = f->mt[r].e;
+
+- nft_set_elem_destroy(set, e, true);
++ nf_tables_set_elem_destroy(ctx, set, e);
+ }
+ }
+
+ /**
+ * nft_pipapo_destroy() - Free private data for set and all committed elements
++ * @ctx: context
+ * @set: nftables API set representation
+ */
+-static void nft_pipapo_destroy(const struct nft_set *set)
++static void nft_pipapo_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set)
+ {
+ struct nft_pipapo *priv = nft_set_priv(set);
+ struct nft_pipapo_match *m;
+@@ -2190,7 +2194,7 @@ static void nft_pipapo_destroy(const str
+ if (m) {
+ rcu_barrier();
+
+- nft_set_pipapo_match_destroy(set, m);
++ nft_set_pipapo_match_destroy(ctx, set, m);
+
+ #ifdef NFT_PIPAPO_ALIGN
+ free_percpu(m->scratch_aligned);
+@@ -2207,7 +2211,7 @@ static void nft_pipapo_destroy(const str
+ m = priv->clone;
+
+ if (priv->dirty)
+- nft_set_pipapo_match_destroy(set, m);
++ nft_set_pipapo_match_destroy(ctx, set, m);
+
+ #ifdef NFT_PIPAPO_ALIGN
+ free_percpu(priv->clone->scratch_aligned);
+--- a/net/netfilter/nft_set_rbtree.c
++++ b/net/netfilter/nft_set_rbtree.c
+@@ -664,7 +664,8 @@ static int nft_rbtree_init(const struct
+ return 0;
+ }
+
+-static void nft_rbtree_destroy(const struct nft_set *set)
++static void nft_rbtree_destroy(const struct nft_ctx *ctx,
++ const struct nft_set *set)
+ {
+ struct nft_rbtree *priv = nft_set_priv(set);
+ struct nft_rbtree_elem *rbe;
+@@ -675,7 +676,7 @@ static void nft_rbtree_destroy(const str
+ while ((node = priv->root.rb_node) != NULL) {
+ rb_erase(node, &priv->root);
+ rbe = rb_entry(node, struct nft_rbtree_elem, node);
+- nft_set_elem_destroy(set, rbe, true);
++ nf_tables_set_elem_destroy(ctx, set, rbe);
+ }
+ }
+