]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
netfilter: cttimeout: detach dataplane timeout policy and repurpose refcount
authorPablo Neira Ayuso <pablo@netfilter.org>
Thu, 4 Jun 2026 06:21:08 +0000 (08:21 +0200)
committerPablo Neira Ayuso <pablo@netfilter.org>
Fri, 5 Jun 2026 11:11:55 +0000 (13:11 +0200)
Add a refcount for struct nf_ct_timeout which is used by ct extension to
set the custom ct timeout policy, this tells us that the ct timeout is
being used by a conntrack entry. When the last conntrack entry drops the
refcount on the ct timeout, the ct timeout is released.

Remove the refcount for control plane which controls if the ruleset
refers to the timeout policy. After this update, it is possible to
remove the ct timeout policy from nfnetlink_cttimeout immediately.
This is for simplicity not to handle two refcounts on a single object.

Remove nf_queue_nf_hook_drop(): a packet sitting in nfqueue will just
hold a reference to the nf_ct_timeout object until packet is reinjected,
since this is part of the ct extension, this will be released by the
time the conntrack is freed.

nf_ct_untimeout() is still called to clean up in a best effort basis:
the ct timeout on existing entries gets removed when the ct timeout goes
away, but as long as the iptables ruleset still refers to the ct timeout
through a template, new conntracks may keep attaching it and extend its
lifetime until the rule is removed.

nf_ct_untimeout() is not called anymore from module removal path, this
is unlikely to find timeouts give module refcount is bumped, and the new
refcount already tracks the ct timeout policy use so it is released when
unused.

Fixes: 50978462300f ("netfilter: add cttimeout infrastructure for fine timeout tuning")
Fixes: 7e0b2b57f01d ("netfilter: nft_ct: add ct timeout support")
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
include/net/netfilter/nf_conntrack_timeout.h
net/netfilter/nf_conntrack_core.c
net/netfilter/nf_conntrack_timeout.c
net/netfilter/nfnetlink_cttimeout.c
net/netfilter/nft_ct.c
net/netfilter/xt_CT.c

index 3a66d4abb6d6885019fc558063b50671ef4f16af..d60aa86be01945c8ae7ae808c6aef73c0d81eca2 100644 (file)
@@ -12,6 +12,7 @@
 #define CTNL_TIMEOUT_NAME_MAX  32
 
 struct nf_ct_timeout {
+       refcount_t              refcnt;
        __u16                   l3num;
        const struct nf_conntrack_l4proto *l4proto;
        struct rcu_head         rcu;
@@ -22,6 +23,22 @@ struct nf_conn_timeout {
        struct nf_ct_timeout __rcu *timeout;
 };
 
+static inline void nf_ct_timeout_put(const struct nf_conn *ct)
+{
+#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
+       struct nf_conn_timeout *timeout_ext;
+       struct nf_ct_timeout *timeout;
+
+       timeout_ext = nf_ct_ext_find(ct, NF_CT_EXT_TIMEOUT);
+       if (!timeout_ext)
+               return;
+
+       timeout = rcu_dereference(timeout_ext->timeout);
+       if (timeout && refcount_dec_and_test(&timeout->refcnt))
+               kfree_rcu(timeout, rcu);
+#endif
+}
+
 static inline unsigned int *
 nf_ct_timeout_data(const struct nf_conn_timeout *t)
 {
@@ -56,8 +73,14 @@ struct nf_conn_timeout *nf_ct_timeout_ext_add(struct nf_conn *ct,
 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
        struct nf_conn_timeout *timeout_ext;
 
+       if (!timeout)
+               return NULL;
+
        timeout_ext = nf_ct_ext_add(ct, NF_CT_EXT_TIMEOUT, gfp);
-       if (timeout_ext == NULL)
+       if (!timeout_ext || timeout_ext->timeout)
+               return NULL;
+
+       if (!refcount_inc_not_zero(&timeout->refcnt))
                return NULL;
 
        rcu_assign_pointer(timeout_ext->timeout, timeout);
@@ -75,7 +98,7 @@ static inline unsigned int *nf_ct_timeout_lookup(const struct nf_conn *ct)
        struct nf_conn_timeout *timeout_ext;
 
        timeout_ext = nf_ct_timeout_find(ct);
-       if (timeout_ext)
+       if (timeout_ext && rcu_access_pointer(timeout_ext->timeout))
                timeouts = nf_ct_timeout_data(timeout_ext);
 #endif
        return timeouts;
index c072a14a306afeb0832c822b0a077e60203b35b7..a45b732393693a69542a2fbf78a56b854f92691e 100644 (file)
@@ -1737,16 +1737,18 @@ void nf_conntrack_free(struct nf_conn *ct)
         */
        WARN_ON(refcount_read(&ct->ct_general.use) != 0);
 
+       rcu_read_lock();
        if (ct->status & IPS_SRC_NAT_DONE) {
                const struct nf_nat_hook *nat_hook;
 
-               rcu_read_lock();
                nat_hook = rcu_dereference(nf_nat_hook);
                if (nat_hook)
                        nat_hook->remove_nat_bysrc(ct);
-               rcu_read_unlock();
        }
 
+       nf_ct_timeout_put(ct);
+       rcu_read_unlock();
+
        kfree(ct->ext);
        kmem_cache_free(nf_conntrack_cachep, ct);
        cnet = nf_ct_pernet(net);
index 0cc584d3dbb1d0ac8b4c1e846e41bfcd620091b5..c81becde2afa9ffb02d70c19b08012f76650dfd7 100644 (file)
 const struct nf_ct_timeout_hooks __rcu *nf_ct_timeout_hook __read_mostly;
 EXPORT_SYMBOL_GPL(nf_ct_timeout_hook);
 
+/* nf_ct_iterate_cleanup() holds the conntrack lock. */
 static int untimeout(struct nf_conn *ct, void *timeout)
 {
        struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct);
 
        if (timeout_ext) {
-               const struct nf_ct_timeout *t;
+               struct nf_ct_timeout *t;
 
-               t = rcu_access_pointer(timeout_ext->timeout);
+               rcu_read_lock();
+               t = rcu_dereference(timeout_ext->timeout);
+               if (!t) {
+                       rcu_read_unlock();
+                       return 0;
+               }
 
-               if (!timeout || t == timeout)
+               if (!timeout || t == timeout) {
                        RCU_INIT_POINTER(timeout_ext->timeout, NULL);
+
+                       /* No race with nf_conntrack_free() which is called
+                        * only after the conntrack has been removed from
+                        * the hashes.
+                        */
+                       if (refcount_dec_and_test(&t->refcnt))
+                               kfree_rcu(t, rcu);
+               }
+               rcu_read_unlock();
        }
 
        /* We are not intended to delete this conntrack. */
@@ -70,6 +85,8 @@ int nf_ct_set_timeout(struct net *net, struct nf_conn *ct,
        const char *errmsg = NULL;
        int ret = 0;
 
+       WARN_ON_ONCE(!nf_ct_is_template(ct));
+
        rcu_read_lock();
        h = rcu_dereference(nf_ct_timeout_hook);
        if (!h) {
@@ -127,6 +144,8 @@ void nf_ct_destroy_timeout(struct nf_conn *ct)
        struct nf_conn_timeout *timeout_ext;
        const struct nf_ct_timeout_hooks *h;
 
+       WARN_ON_ONCE(!nf_ct_is_template(ct));
+
        rcu_read_lock();
        h = rcu_dereference(nf_ct_timeout_hook);
 
@@ -139,6 +158,8 @@ void nf_ct_destroy_timeout(struct nf_conn *ct)
                        if (t)
                                h->timeout_put(t);
                        RCU_INIT_POINTER(timeout_ext->timeout, NULL);
+                       if (t && refcount_dec_and_test(&t->refcnt))
+                               kfree_rcu(t, rcu);
                }
        }
        rcu_read_unlock();
index dca6826af7de3fb8ee69bcc4f41f0165fb806fe2..170d3db860c564a84d2cbcd702530fa1b676524f 100644 (file)
@@ -37,11 +37,8 @@ struct ctnl_timeout {
        struct list_head        head;
        struct list_head        free_head;
        struct rcu_head         rcu_head;
-       refcount_t              refcnt;
        char                    name[CTNL_TIMEOUT_NAME_MAX];
-
-       /* must be at the end */
-       struct nf_ct_timeout    timeout;
+       struct nf_ct_timeout    *timeout;
 };
 
 struct nfct_timeout_pernet {
@@ -132,12 +129,12 @@ static int cttimeout_new_timeout(struct sk_buff *skb,
                        /* You cannot replace one timeout policy by another of
                         * different kind, sorry.
                         */
-                       if (matching->timeout.l3num != l3num ||
-                           matching->timeout.l4proto->l4proto != l4num)
+                       if (matching->timeout->l3num != l3num ||
+                           matching->timeout->l4proto->l4proto != l4num)
                                return -EINVAL;
 
-                       return ctnl_timeout_parse_policy(&matching->timeout.data,
-                                                        matching->timeout.l4proto,
+                       return ctnl_timeout_parse_policy(&matching->timeout->data,
+                                                        matching->timeout->l4proto,
                                                         info->net,
                                                         cda[CTA_TIMEOUT_DATA]);
                }
@@ -153,26 +150,35 @@ static int cttimeout_new_timeout(struct sk_buff *skb,
                goto err_proto_put;
        }
 
-       timeout = kzalloc(sizeof(struct ctnl_timeout) +
-                         l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
+       timeout = kzalloc(sizeof(*timeout), GFP_KERNEL);
        if (timeout == NULL) {
                ret = -ENOMEM;
                goto err_proto_put;
        }
 
-       ret = ctnl_timeout_parse_policy(&timeout->timeout.data, l4proto,
+       timeout->timeout = kzalloc(sizeof(*timeout->timeout) +
+                                  l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
+       if (!timeout->timeout) {
+               ret = -ENOMEM;
+               goto err;
+       }
+
+       ret = ctnl_timeout_parse_policy(&timeout->timeout->data, l4proto,
                                        info->net, cda[CTA_TIMEOUT_DATA]);
        if (ret < 0)
-               goto err;
+               goto err_free_timeout_policy;
 
        strcpy(timeout->name, nla_data(cda[CTA_TIMEOUT_NAME]));
-       timeout->timeout.l3num = l3num;
-       timeout->timeout.l4proto = l4proto;
-       refcount_set(&timeout->refcnt, 1);
+       timeout->timeout->l3num = l3num;
+       timeout->timeout->l4proto = l4proto;
+       refcount_set(&timeout->timeout->refcnt, 1);
        __module_get(THIS_MODULE);
        list_add_tail_rcu(&timeout->head, &pernet->nfct_timeout_list);
 
        return 0;
+
+err_free_timeout_policy:
+       kfree(timeout->timeout);
 err:
        kfree(timeout);
 err_proto_put:
@@ -185,7 +191,7 @@ ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type,
 {
        struct nlmsghdr *nlh;
        unsigned int flags = portid ? NLM_F_MULTI : 0;
-       const struct nf_conntrack_l4proto *l4proto = timeout->timeout.l4proto;
+       const struct nf_conntrack_l4proto *l4proto = timeout->timeout->l4proto;
        struct nlattr *nest_parms;
        int ret;
 
@@ -197,17 +203,17 @@ ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type,
 
        if (nla_put_string(skb, CTA_TIMEOUT_NAME, timeout->name) ||
            nla_put_be16(skb, CTA_TIMEOUT_L3PROTO,
-                        htons(timeout->timeout.l3num)) ||
+                        htons(timeout->timeout->l3num)) ||
            nla_put_u8(skb, CTA_TIMEOUT_L4PROTO, l4proto->l4proto) ||
            nla_put_be32(skb, CTA_TIMEOUT_USE,
-                        htonl(refcount_read(&timeout->refcnt))))
+                        htonl(refcount_read(&timeout->timeout->refcnt))))
                goto nla_put_failure;
 
        nest_parms = nla_nest_start(skb, CTA_TIMEOUT_DATA);
        if (!nest_parms)
                goto nla_put_failure;
 
-       ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->timeout.data);
+       ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->timeout->data);
        if (ret < 0)
                goto nla_put_failure;
 
@@ -307,23 +313,17 @@ static int cttimeout_get_timeout(struct sk_buff *skb,
        return ret;
 }
 
-/* try to delete object, fail if it is still in use. */
-static int ctnl_timeout_try_del(struct net *net, struct ctnl_timeout *timeout)
+static void ctnl_timeout_del(struct net *net, struct ctnl_timeout *timeout)
 {
-       int ret = 0;
+       /* We are protected by nfnl mutex. */
+       list_del_rcu(&timeout->head);
+       nf_ct_untimeout(net, timeout->timeout);
 
-       /* We want to avoid races with ctnl_timeout_put. So only when the
-        * current refcnt is 1, we decrease it to 0.
-        */
-       if (refcount_dec_if_one(&timeout->refcnt)) {
-               /* We are protected by nfnl mutex. */
-               list_del_rcu(&timeout->head);
-               nf_ct_untimeout(net, &timeout->timeout);
-               kfree_rcu(timeout, rcu_head);
-       } else {
-               ret = -EBUSY;
-       }
-       return ret;
+       if (refcount_dec_and_test(&timeout->timeout->refcnt))
+               kfree_rcu(timeout->timeout, rcu);
+
+       kfree_rcu(timeout, rcu_head);
+       module_put(THIS_MODULE);
 }
 
 static int cttimeout_del_timeout(struct sk_buff *skb,
@@ -338,7 +338,7 @@ static int cttimeout_del_timeout(struct sk_buff *skb,
        if (!cda[CTA_TIMEOUT_NAME]) {
                list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_list,
                                         head)
-                       ctnl_timeout_try_del(info->net, cur);
+                       ctnl_timeout_del(info->net, cur);
 
                return 0;
        }
@@ -348,10 +348,8 @@ static int cttimeout_del_timeout(struct sk_buff *skb,
                if (strncmp(cur->name, name, CTNL_TIMEOUT_NAME_MAX) != 0)
                        continue;
 
-               ret = ctnl_timeout_try_del(info->net, cur);
-               if (ret < 0)
-                       return ret;
-
+               ctnl_timeout_del(info->net, cur);
+               ret = 0;
                break;
        }
        return ret;
@@ -511,24 +509,22 @@ static struct nf_ct_timeout *ctnl_timeout_find_get(struct net *net,
                if (strncmp(timeout->name, name, CTNL_TIMEOUT_NAME_MAX) != 0)
                        continue;
 
-               if (!refcount_inc_not_zero(&timeout->refcnt))
+               if (!refcount_inc_not_zero(&timeout->timeout->refcnt))
                        goto err;
                matching = timeout;
+               __module_get(THIS_MODULE);
                break;
        }
 err:
-       return matching ? &matching->timeout : NULL;
+       return matching ? matching->timeout : NULL;
 }
 
-static void ctnl_timeout_put(struct nf_ct_timeout *t)
+static void ctnl_timeout_put(struct nf_ct_timeout *timeout)
 {
-       struct ctnl_timeout *timeout =
-               container_of(t, struct ctnl_timeout, timeout);
+       if (refcount_dec_and_test(&timeout->refcnt))
+               kfree_rcu(timeout, rcu);
 
-       if (refcount_dec_and_test(&timeout->refcnt)) {
-               kfree_rcu(timeout, rcu_head);
-               module_put(THIS_MODULE);
-       }
+       module_put(THIS_MODULE);
 }
 
 static const struct nfnl_callback cttimeout_cb[IPCTNL_MSG_TIMEOUT_MAX] = {
@@ -609,8 +605,11 @@ static void __net_exit cttimeout_net_exit(struct net *net)
        list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_freelist, free_head) {
                list_del(&cur->free_head);
 
-               if (refcount_dec_and_test(&cur->refcnt))
-                       kfree_rcu(cur, rcu_head);
+               if (refcount_dec_and_test(&cur->timeout->refcnt))
+                       kfree_rcu(cur->timeout, rcu);
+
+               kfree_rcu(cur, rcu_head);
+               module_put(THIS_MODULE);
        }
 }
 
@@ -649,24 +648,13 @@ err_out:
        return ret;
 }
 
-static int untimeout(struct nf_conn *ct, void *timeout)
-{
-       struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct);
-
-       if (timeout_ext)
-               RCU_INIT_POINTER(timeout_ext->timeout, NULL);
-
-       return 0;
-}
-
 static void __exit cttimeout_exit(void)
 {
        nfnetlink_subsys_unregister(&cttimeout_subsys);
 
        unregister_pernet_subsys(&cttimeout_ops);
        RCU_INIT_POINTER(nf_ct_timeout_hook, NULL);
-
-       nf_ct_iterate_destroy(untimeout, NULL);
+       synchronize_net();
 }
 
 module_init(cttimeout_init);
index 357513c6dcea08d91c4b498af6589b22a398ab5d..801c01c6af95fc4f6b53b7417cfb20f777cbbbd0 100644 (file)
@@ -897,8 +897,6 @@ static void nft_ct_timeout_obj_eval(struct nft_object *obj,
                }
        }
 
-       rcu_assign_pointer(timeout->timeout, priv->timeout);
-
        /* adjust the timeout as per 'new' state. ct is unconfirmed,
         * so the current timestamp must not be added.
         */
@@ -949,6 +947,7 @@ static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx,
 
        timeout->l3num = l3num;
        timeout->l4proto = l4proto;
+       refcount_set(&timeout->refcnt, 1);
 
        ret = nf_ct_netns_get(ctx->net, ctx->family);
        if (ret < 0)
@@ -969,10 +968,10 @@ static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx,
        struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
        struct nf_ct_timeout *timeout = priv->timeout;
 
-       nf_queue_nf_hook_drop(ctx->net);
        nf_ct_untimeout(ctx->net, timeout);
        nf_ct_netns_put(ctx->net, ctx->family);
-       kfree_rcu(priv->timeout, rcu);
+       if (refcount_dec_and_test(&timeout->refcnt))
+               kfree_rcu(priv->timeout, rcu);
 }
 
 static int nft_ct_timeout_obj_dump(struct sk_buff *skb,
index d2aeacf94230f8970ce86ed14e548a320ebb3df9..b94f004d5f5c2737f84dfe8df5fa493f285fe5aa 100644 (file)
@@ -284,7 +284,7 @@ static void xt_ct_tg_destroy(const struct xt_tgdtor_param *par,
        struct nf_conn_help *help;
 
        if (ct) {
-               if (info->helper[0] || info->timeout[0])
+               if (info->helper[0])
                        nf_queue_nf_hook_drop(par->net);
 
                help = nfct_help(ct);