From: Pablo Neira Ayuso Date: Thu, 4 Jun 2026 06:21:08 +0000 (+0200) Subject: netfilter: cttimeout: detach dataplane timeout policy and repurpose refcount X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7d6a9cdb8d3a51d9cfe546a09a518ab3d2671549;p=thirdparty%2Fkernel%2Flinux.git netfilter: cttimeout: detach dataplane timeout policy and repurpose refcount Add a refcount for struct nf_ct_timeout which is used by ct extension to set the custom ct timeout policy, this tells us that the ct timeout is being used by a conntrack entry. When the last conntrack entry drops the refcount on the ct timeout, the ct timeout is released. Remove the refcount for control plane which controls if the ruleset refers to the timeout policy. After this update, it is possible to remove the ct timeout policy from nfnetlink_cttimeout immediately. This is for simplicity not to handle two refcounts on a single object. Remove nf_queue_nf_hook_drop(): a packet sitting in nfqueue will just hold a reference to the nf_ct_timeout object until packet is reinjected, since this is part of the ct extension, this will be released by the time the conntrack is freed. nf_ct_untimeout() is still called to clean up in a best effort basis: the ct timeout on existing entries gets removed when the ct timeout goes away, but as long as the iptables ruleset still refers to the ct timeout through a template, new conntracks may keep attaching it and extend its lifetime until the rule is removed. nf_ct_untimeout() is not called anymore from module removal path, this is unlikely to find timeouts give module refcount is bumped, and the new refcount already tracks the ct timeout policy use so it is released when unused. Fixes: 50978462300f ("netfilter: add cttimeout infrastructure for fine timeout tuning") Fixes: 7e0b2b57f01d ("netfilter: nft_ct: add ct timeout support") Signed-off-by: Pablo Neira Ayuso --- diff --git a/include/net/netfilter/nf_conntrack_timeout.h b/include/net/netfilter/nf_conntrack_timeout.h index 3a66d4abb6d68..d60aa86be0194 100644 --- a/include/net/netfilter/nf_conntrack_timeout.h +++ b/include/net/netfilter/nf_conntrack_timeout.h @@ -12,6 +12,7 @@ #define CTNL_TIMEOUT_NAME_MAX 32 struct nf_ct_timeout { + refcount_t refcnt; __u16 l3num; const struct nf_conntrack_l4proto *l4proto; struct rcu_head rcu; @@ -22,6 +23,22 @@ struct nf_conn_timeout { struct nf_ct_timeout __rcu *timeout; }; +static inline void nf_ct_timeout_put(const struct nf_conn *ct) +{ +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + struct nf_conn_timeout *timeout_ext; + struct nf_ct_timeout *timeout; + + timeout_ext = nf_ct_ext_find(ct, NF_CT_EXT_TIMEOUT); + if (!timeout_ext) + return; + + timeout = rcu_dereference(timeout_ext->timeout); + if (timeout && refcount_dec_and_test(&timeout->refcnt)) + kfree_rcu(timeout, rcu); +#endif +} + static inline unsigned int * nf_ct_timeout_data(const struct nf_conn_timeout *t) { @@ -56,8 +73,14 @@ struct nf_conn_timeout *nf_ct_timeout_ext_add(struct nf_conn *ct, #ifdef CONFIG_NF_CONNTRACK_TIMEOUT struct nf_conn_timeout *timeout_ext; + if (!timeout) + return NULL; + timeout_ext = nf_ct_ext_add(ct, NF_CT_EXT_TIMEOUT, gfp); - if (timeout_ext == NULL) + if (!timeout_ext || timeout_ext->timeout) + return NULL; + + if (!refcount_inc_not_zero(&timeout->refcnt)) return NULL; rcu_assign_pointer(timeout_ext->timeout, timeout); @@ -75,7 +98,7 @@ static inline unsigned int *nf_ct_timeout_lookup(const struct nf_conn *ct) struct nf_conn_timeout *timeout_ext; timeout_ext = nf_ct_timeout_find(ct); - if (timeout_ext) + if (timeout_ext && rcu_access_pointer(timeout_ext->timeout)) timeouts = nf_ct_timeout_data(timeout_ext); #endif return timeouts; diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c index c072a14a306af..a45b732393693 100644 --- a/net/netfilter/nf_conntrack_core.c +++ b/net/netfilter/nf_conntrack_core.c @@ -1737,16 +1737,18 @@ void nf_conntrack_free(struct nf_conn *ct) */ WARN_ON(refcount_read(&ct->ct_general.use) != 0); + rcu_read_lock(); if (ct->status & IPS_SRC_NAT_DONE) { const struct nf_nat_hook *nat_hook; - rcu_read_lock(); nat_hook = rcu_dereference(nf_nat_hook); if (nat_hook) nat_hook->remove_nat_bysrc(ct); - rcu_read_unlock(); } + nf_ct_timeout_put(ct); + rcu_read_unlock(); + kfree(ct->ext); kmem_cache_free(nf_conntrack_cachep, ct); cnet = nf_ct_pernet(net); diff --git a/net/netfilter/nf_conntrack_timeout.c b/net/netfilter/nf_conntrack_timeout.c index 0cc584d3dbb1d..c81becde2afa9 100644 --- a/net/netfilter/nf_conntrack_timeout.c +++ b/net/netfilter/nf_conntrack_timeout.c @@ -25,17 +25,32 @@ const struct nf_ct_timeout_hooks __rcu *nf_ct_timeout_hook __read_mostly; EXPORT_SYMBOL_GPL(nf_ct_timeout_hook); +/* nf_ct_iterate_cleanup() holds the conntrack lock. */ static int untimeout(struct nf_conn *ct, void *timeout) { struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct); if (timeout_ext) { - const struct nf_ct_timeout *t; + struct nf_ct_timeout *t; - t = rcu_access_pointer(timeout_ext->timeout); + rcu_read_lock(); + t = rcu_dereference(timeout_ext->timeout); + if (!t) { + rcu_read_unlock(); + return 0; + } - if (!timeout || t == timeout) + if (!timeout || t == timeout) { RCU_INIT_POINTER(timeout_ext->timeout, NULL); + + /* No race with nf_conntrack_free() which is called + * only after the conntrack has been removed from + * the hashes. + */ + if (refcount_dec_and_test(&t->refcnt)) + kfree_rcu(t, rcu); + } + rcu_read_unlock(); } /* We are not intended to delete this conntrack. */ @@ -70,6 +85,8 @@ int nf_ct_set_timeout(struct net *net, struct nf_conn *ct, const char *errmsg = NULL; int ret = 0; + WARN_ON_ONCE(!nf_ct_is_template(ct)); + rcu_read_lock(); h = rcu_dereference(nf_ct_timeout_hook); if (!h) { @@ -127,6 +144,8 @@ void nf_ct_destroy_timeout(struct nf_conn *ct) struct nf_conn_timeout *timeout_ext; const struct nf_ct_timeout_hooks *h; + WARN_ON_ONCE(!nf_ct_is_template(ct)); + rcu_read_lock(); h = rcu_dereference(nf_ct_timeout_hook); @@ -139,6 +158,8 @@ void nf_ct_destroy_timeout(struct nf_conn *ct) if (t) h->timeout_put(t); RCU_INIT_POINTER(timeout_ext->timeout, NULL); + if (t && refcount_dec_and_test(&t->refcnt)) + kfree_rcu(t, rcu); } } rcu_read_unlock(); diff --git a/net/netfilter/nfnetlink_cttimeout.c b/net/netfilter/nfnetlink_cttimeout.c index dca6826af7de3..170d3db860c56 100644 --- a/net/netfilter/nfnetlink_cttimeout.c +++ b/net/netfilter/nfnetlink_cttimeout.c @@ -37,11 +37,8 @@ struct ctnl_timeout { struct list_head head; struct list_head free_head; struct rcu_head rcu_head; - refcount_t refcnt; char name[CTNL_TIMEOUT_NAME_MAX]; - - /* must be at the end */ - struct nf_ct_timeout timeout; + struct nf_ct_timeout *timeout; }; struct nfct_timeout_pernet { @@ -132,12 +129,12 @@ static int cttimeout_new_timeout(struct sk_buff *skb, /* You cannot replace one timeout policy by another of * different kind, sorry. */ - if (matching->timeout.l3num != l3num || - matching->timeout.l4proto->l4proto != l4num) + if (matching->timeout->l3num != l3num || + matching->timeout->l4proto->l4proto != l4num) return -EINVAL; - return ctnl_timeout_parse_policy(&matching->timeout.data, - matching->timeout.l4proto, + return ctnl_timeout_parse_policy(&matching->timeout->data, + matching->timeout->l4proto, info->net, cda[CTA_TIMEOUT_DATA]); } @@ -153,26 +150,35 @@ static int cttimeout_new_timeout(struct sk_buff *skb, goto err_proto_put; } - timeout = kzalloc(sizeof(struct ctnl_timeout) + - l4proto->ctnl_timeout.obj_size, GFP_KERNEL); + timeout = kzalloc(sizeof(*timeout), GFP_KERNEL); if (timeout == NULL) { ret = -ENOMEM; goto err_proto_put; } - ret = ctnl_timeout_parse_policy(&timeout->timeout.data, l4proto, + timeout->timeout = kzalloc(sizeof(*timeout->timeout) + + l4proto->ctnl_timeout.obj_size, GFP_KERNEL); + if (!timeout->timeout) { + ret = -ENOMEM; + goto err; + } + + ret = ctnl_timeout_parse_policy(&timeout->timeout->data, l4proto, info->net, cda[CTA_TIMEOUT_DATA]); if (ret < 0) - goto err; + goto err_free_timeout_policy; strcpy(timeout->name, nla_data(cda[CTA_TIMEOUT_NAME])); - timeout->timeout.l3num = l3num; - timeout->timeout.l4proto = l4proto; - refcount_set(&timeout->refcnt, 1); + timeout->timeout->l3num = l3num; + timeout->timeout->l4proto = l4proto; + refcount_set(&timeout->timeout->refcnt, 1); __module_get(THIS_MODULE); list_add_tail_rcu(&timeout->head, &pernet->nfct_timeout_list); return 0; + +err_free_timeout_policy: + kfree(timeout->timeout); err: kfree(timeout); err_proto_put: @@ -185,7 +191,7 @@ ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, { struct nlmsghdr *nlh; unsigned int flags = portid ? NLM_F_MULTI : 0; - const struct nf_conntrack_l4proto *l4proto = timeout->timeout.l4proto; + const struct nf_conntrack_l4proto *l4proto = timeout->timeout->l4proto; struct nlattr *nest_parms; int ret; @@ -197,17 +203,17 @@ ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, if (nla_put_string(skb, CTA_TIMEOUT_NAME, timeout->name) || nla_put_be16(skb, CTA_TIMEOUT_L3PROTO, - htons(timeout->timeout.l3num)) || + htons(timeout->timeout->l3num)) || nla_put_u8(skb, CTA_TIMEOUT_L4PROTO, l4proto->l4proto) || nla_put_be32(skb, CTA_TIMEOUT_USE, - htonl(refcount_read(&timeout->refcnt)))) + htonl(refcount_read(&timeout->timeout->refcnt)))) goto nla_put_failure; nest_parms = nla_nest_start(skb, CTA_TIMEOUT_DATA); if (!nest_parms) goto nla_put_failure; - ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->timeout.data); + ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->timeout->data); if (ret < 0) goto nla_put_failure; @@ -307,23 +313,17 @@ static int cttimeout_get_timeout(struct sk_buff *skb, return ret; } -/* try to delete object, fail if it is still in use. */ -static int ctnl_timeout_try_del(struct net *net, struct ctnl_timeout *timeout) +static void ctnl_timeout_del(struct net *net, struct ctnl_timeout *timeout) { - int ret = 0; + /* We are protected by nfnl mutex. */ + list_del_rcu(&timeout->head); + nf_ct_untimeout(net, timeout->timeout); - /* We want to avoid races with ctnl_timeout_put. So only when the - * current refcnt is 1, we decrease it to 0. - */ - if (refcount_dec_if_one(&timeout->refcnt)) { - /* We are protected by nfnl mutex. */ - list_del_rcu(&timeout->head); - nf_ct_untimeout(net, &timeout->timeout); - kfree_rcu(timeout, rcu_head); - } else { - ret = -EBUSY; - } - return ret; + if (refcount_dec_and_test(&timeout->timeout->refcnt)) + kfree_rcu(timeout->timeout, rcu); + + kfree_rcu(timeout, rcu_head); + module_put(THIS_MODULE); } static int cttimeout_del_timeout(struct sk_buff *skb, @@ -338,7 +338,7 @@ static int cttimeout_del_timeout(struct sk_buff *skb, if (!cda[CTA_TIMEOUT_NAME]) { list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_list, head) - ctnl_timeout_try_del(info->net, cur); + ctnl_timeout_del(info->net, cur); return 0; } @@ -348,10 +348,8 @@ static int cttimeout_del_timeout(struct sk_buff *skb, if (strncmp(cur->name, name, CTNL_TIMEOUT_NAME_MAX) != 0) continue; - ret = ctnl_timeout_try_del(info->net, cur); - if (ret < 0) - return ret; - + ctnl_timeout_del(info->net, cur); + ret = 0; break; } return ret; @@ -511,24 +509,22 @@ static struct nf_ct_timeout *ctnl_timeout_find_get(struct net *net, if (strncmp(timeout->name, name, CTNL_TIMEOUT_NAME_MAX) != 0) continue; - if (!refcount_inc_not_zero(&timeout->refcnt)) + if (!refcount_inc_not_zero(&timeout->timeout->refcnt)) goto err; matching = timeout; + __module_get(THIS_MODULE); break; } err: - return matching ? &matching->timeout : NULL; + return matching ? matching->timeout : NULL; } -static void ctnl_timeout_put(struct nf_ct_timeout *t) +static void ctnl_timeout_put(struct nf_ct_timeout *timeout) { - struct ctnl_timeout *timeout = - container_of(t, struct ctnl_timeout, timeout); + if (refcount_dec_and_test(&timeout->refcnt)) + kfree_rcu(timeout, rcu); - if (refcount_dec_and_test(&timeout->refcnt)) { - kfree_rcu(timeout, rcu_head); - module_put(THIS_MODULE); - } + module_put(THIS_MODULE); } static const struct nfnl_callback cttimeout_cb[IPCTNL_MSG_TIMEOUT_MAX] = { @@ -609,8 +605,11 @@ static void __net_exit cttimeout_net_exit(struct net *net) list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_freelist, free_head) { list_del(&cur->free_head); - if (refcount_dec_and_test(&cur->refcnt)) - kfree_rcu(cur, rcu_head); + if (refcount_dec_and_test(&cur->timeout->refcnt)) + kfree_rcu(cur->timeout, rcu); + + kfree_rcu(cur, rcu_head); + module_put(THIS_MODULE); } } @@ -649,24 +648,13 @@ err_out: return ret; } -static int untimeout(struct nf_conn *ct, void *timeout) -{ - struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct); - - if (timeout_ext) - RCU_INIT_POINTER(timeout_ext->timeout, NULL); - - return 0; -} - static void __exit cttimeout_exit(void) { nfnetlink_subsys_unregister(&cttimeout_subsys); unregister_pernet_subsys(&cttimeout_ops); RCU_INIT_POINTER(nf_ct_timeout_hook, NULL); - - nf_ct_iterate_destroy(untimeout, NULL); + synchronize_net(); } module_init(cttimeout_init); diff --git a/net/netfilter/nft_ct.c b/net/netfilter/nft_ct.c index 357513c6dcea0..801c01c6af95f 100644 --- a/net/netfilter/nft_ct.c +++ b/net/netfilter/nft_ct.c @@ -897,8 +897,6 @@ static void nft_ct_timeout_obj_eval(struct nft_object *obj, } } - rcu_assign_pointer(timeout->timeout, priv->timeout); - /* adjust the timeout as per 'new' state. ct is unconfirmed, * so the current timestamp must not be added. */ @@ -949,6 +947,7 @@ static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx, timeout->l3num = l3num; timeout->l4proto = l4proto; + refcount_set(&timeout->refcnt, 1); ret = nf_ct_netns_get(ctx->net, ctx->family); if (ret < 0) @@ -969,10 +968,10 @@ static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx, struct nft_ct_timeout_obj *priv = nft_obj_data(obj); struct nf_ct_timeout *timeout = priv->timeout; - nf_queue_nf_hook_drop(ctx->net); nf_ct_untimeout(ctx->net, timeout); nf_ct_netns_put(ctx->net, ctx->family); - kfree_rcu(priv->timeout, rcu); + if (refcount_dec_and_test(&timeout->refcnt)) + kfree_rcu(priv->timeout, rcu); } static int nft_ct_timeout_obj_dump(struct sk_buff *skb, diff --git a/net/netfilter/xt_CT.c b/net/netfilter/xt_CT.c index d2aeacf94230f..b94f004d5f5c2 100644 --- a/net/netfilter/xt_CT.c +++ b/net/netfilter/xt_CT.c @@ -284,7 +284,7 @@ static void xt_ct_tg_destroy(const struct xt_tgdtor_param *par, struct nf_conn_help *help; if (ct) { - if (info->helper[0] || info->timeout[0]) + if (info->helper[0]) nf_queue_nf_hook_drop(par->net); help = nfct_help(ct);