]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
ipmr: Free mr_table after RCU grace period.
authorKuniyuki Iwashima <kuniyu@google.com>
Thu, 23 Apr 2026 05:34:54 +0000 (05:34 +0000)
committerJakub Kicinski <kuba@kernel.org>
Tue, 28 Apr 2026 01:46:17 +0000 (18:46 -0700)
With CONFIG_IP_MROUTE_MULTIPLE_TABLES=n, ipmr_fib_lookup()
does not check if net->ipv4.mrt is NULL.

Since default_device_exit_batch() is called after ->exit_rtnl(),
a device could receive IGMP packets and access net->ipv4.mrt
during/after ipmr_rules_exit_rtnl().

If ipmr_rules_exit_rtnl() had already cleared it and freed the
memory, the access would trigger null-ptr-deref or use-after-free.

Let's fix it by using RCU helper and free mrt after RCU grace
period.

In addition, check_net(net) is added to mroute_clean_tables()
and ipmr_cache_unresolved() to synchronise via mfc_unres_lock.
This prevents ipmr_cache_unresolved() from putting skb into
c->_c.mfc_un.unres.unresolved after mroute_clean_tables()
purges it.

For the same reason, timer_shutdown_sync() is moved after
mroute_clean_tables().

Since rhltable_destroy() holds mutex internally, rcu_work is
used, and it is placed as the first member because rcu_head
must be placed within <4K offset.  mr_table is alraedy 3864
bytes without rcu_work.

Note that IP6MR is not yet converted to ->exit_rtnl(), so this
change is not needed for now but will be.

Fixes: b22b01867406 ("ipmr: Convert ipmr_net_exit_batch() to ->exit_rtnl().")
Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Link: https://patch.msgid.link/20260423053456.4097409-1-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/linux/mroute_base.h
net/ipv4/ipmr.c
net/ipv4/ipmr_base.c

index cf3374580f744dadc731f5f92da0518efc26c9be..5d75cc5b057eeead07111d37907833b999fb1ee0 100644 (file)
@@ -226,6 +226,7 @@ struct mr_table_ops {
 
 /**
  * struct mr_table - a multicast routing table
+ * @work: used for table destruction
  * @list: entry within a list of multicast routing tables
  * @net: net where this table belongs
  * @ops: protocol specific operations
@@ -243,6 +244,7 @@ struct mr_table_ops {
  * @mroute_reg_vif_num: PIM-device vif index
  */
 struct mr_table {
+       struct rcu_work         work;
        struct list_head        list;
        possible_net_t          net;
        struct mr_table_ops     ops;
@@ -274,6 +276,7 @@ void vif_device_init(struct vif_device *v,
                     unsigned short flags,
                     unsigned short get_iflink_mask);
 
+void mr_table_free(struct mr_table *mrt);
 struct mr_table *
 mr_table_alloc(struct net *net, u32 id,
               struct mr_table_ops *ops,
index 8a08d09b4c309b15ec05c395aadd7defacf404a7..2058ca860294b01385063555d0354b7a9a736118 100644 (file)
@@ -151,16 +151,6 @@ static struct mr_table *__ipmr_get_table(struct net *net, u32 id)
        return NULL;
 }
 
-static struct mr_table *ipmr_get_table(struct net *net, u32 id)
-{
-       struct mr_table *mrt;
-
-       rcu_read_lock();
-       mrt = __ipmr_get_table(net, id);
-       rcu_read_unlock();
-       return mrt;
-}
-
 static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
                           struct mr_table **mrt)
 {
@@ -293,7 +283,7 @@ static void __net_exit ipmr_rules_exit_rtnl(struct net *net,
        struct mr_table *mrt, *next;
 
        list_for_each_entry_safe(mrt, next, &net->ipv4.mr_tables, list) {
-               list_del(&mrt->list);
+               list_del_rcu(&mrt->list);
                ipmr_free_table(mrt, dev_kill_list);
        }
 }
@@ -315,28 +305,30 @@ bool ipmr_rule_default(const struct fib_rule *rule)
 }
 EXPORT_SYMBOL(ipmr_rule_default);
 #else
-#define ipmr_for_each_table(mrt, net) \
-       for (mrt = net->ipv4.mrt; mrt; mrt = NULL)
-
 static struct mr_table *ipmr_mr_table_iter(struct net *net,
                                           struct mr_table *mrt)
 {
        if (!mrt)
-               return net->ipv4.mrt;
+               return rcu_dereference(net->ipv4.mrt);
        return NULL;
 }
 
-static struct mr_table *ipmr_get_table(struct net *net, u32 id)
+static struct mr_table *__ipmr_get_table(struct net *net, u32 id)
 {
-       return net->ipv4.mrt;
+       return rcu_dereference_check(net->ipv4.mrt,
+                                    lockdep_rtnl_is_held() ||
+                                    !rcu_access_pointer(net->ipv4.mrt));
 }
 
-#define __ipmr_get_table ipmr_get_table
+#define ipmr_for_each_table(mrt, net)                          \
+       for (mrt = __ipmr_get_table(net, 0); mrt; mrt = NULL)
 
 static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
                           struct mr_table **mrt)
 {
-       *mrt = net->ipv4.mrt;
+       *mrt = rcu_dereference(net->ipv4.mrt);
+       if (!*mrt)
+               return -EAGAIN;
        return 0;
 }
 
@@ -347,7 +339,8 @@ static int __net_init ipmr_rules_init(struct net *net)
        mrt = ipmr_new_table(net, RT_TABLE_DEFAULT);
        if (IS_ERR(mrt))
                return PTR_ERR(mrt);
-       net->ipv4.mrt = mrt;
+
+       rcu_assign_pointer(net->ipv4.mrt, mrt);
        return 0;
 }
 
@@ -358,9 +351,10 @@ static void __net_exit ipmr_rules_exit(struct net *net)
 static void __net_exit ipmr_rules_exit_rtnl(struct net *net,
                                            struct list_head *dev_kill_list)
 {
-       ipmr_free_table(net->ipv4.mrt, dev_kill_list);
+       struct mr_table *mrt = rcu_dereference_protected(net->ipv4.mrt, 1);
 
-       net->ipv4.mrt = NULL;
+       RCU_INIT_POINTER(net->ipv4.mrt, NULL);
+       ipmr_free_table(mrt, dev_kill_list);
 }
 
 static int ipmr_rules_dump(struct net *net, struct notifier_block *nb,
@@ -381,6 +375,17 @@ bool ipmr_rule_default(const struct fib_rule *rule)
 EXPORT_SYMBOL(ipmr_rule_default);
 #endif
 
+static struct mr_table *ipmr_get_table(struct net *net, u32 id)
+{
+       struct mr_table *mrt;
+
+       rcu_read_lock();
+       mrt = __ipmr_get_table(net, id);
+       rcu_read_unlock();
+
+       return mrt;
+}
+
 static inline int ipmr_hash_cmp(struct rhashtable_compare_arg *arg,
                                const void *ptr)
 {
@@ -441,12 +446,11 @@ static void ipmr_free_table(struct mr_table *mrt, struct list_head *dev_kill_lis
 
        WARN_ON_ONCE(!mr_can_free_table(net));
 
-       timer_shutdown_sync(&mrt->ipmr_expire_timer);
        mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC |
                            MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC,
                            &ipmr_dev_kill_list);
-       rhltable_destroy(&mrt->mfc_hash);
-       kfree(mrt);
+       timer_shutdown_sync(&mrt->ipmr_expire_timer);
+       mr_table_free(mrt);
 
        WARN_ON_ONCE(!net_initialized(net) && !list_empty(&ipmr_dev_kill_list));
        list_splice(&ipmr_dev_kill_list, dev_kill_list);
@@ -1135,12 +1139,19 @@ static int ipmr_cache_report(const struct mr_table *mrt,
 static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,
                                 struct sk_buff *skb, struct net_device *dev)
 {
+       struct net *net = read_pnet(&mrt->net);
        const struct iphdr *iph = ip_hdr(skb);
-       struct mfc_cache *c;
+       struct mfc_cache *c = NULL;
        bool found = false;
        int err;
 
        spin_lock_bh(&mfc_unres_lock);
+
+       if (!check_net(net)) {
+               err = -EINVAL;
+               goto err;
+       }
+
        list_for_each_entry(c, &mrt->mfc_unres_queue, _c.list) {
                if (c->mfc_mcastgrp == iph->daddr &&
                    c->mfc_origin == iph->saddr) {
@@ -1153,10 +1164,8 @@ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,
                /* Create a new entry if allowable */
                c = ipmr_cache_alloc_unres();
                if (!c) {
-                       spin_unlock_bh(&mfc_unres_lock);
-
-                       kfree_skb(skb);
-                       return -ENOBUFS;
+                       err = -ENOBUFS;
+                       goto err;
                }
 
                /* Fill in the new cache entry */
@@ -1166,17 +1175,8 @@ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,
 
                /* Reflect first query at mrouted. */
                err = ipmr_cache_report(mrt, skb, vifi, IGMPMSG_NOCACHE);
-
-               if (err < 0) {
-                       /* If the report failed throw the cache entry
-                          out - Brad Parker
-                        */
-                       spin_unlock_bh(&mfc_unres_lock);
-
-                       ipmr_cache_free(c);
-                       kfree_skb(skb);
-                       return err;
-               }
+               if (err < 0)
+                       goto err;
 
                atomic_inc(&mrt->cache_resolve_queue_len);
                list_add(&c->_c.list, &mrt->mfc_unres_queue);
@@ -1189,18 +1189,26 @@ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,
 
        /* See if we can append the packet */
        if (c->_c.mfc_un.unres.unresolved.qlen > 3) {
-               kfree_skb(skb);
+               c = NULL;
                err = -ENOBUFS;
-       } else {
-               if (dev) {
-                       skb->dev = dev;
-                       skb->skb_iif = dev->ifindex;
-               }
-               skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb);
-               err = 0;
+               goto err;
+       }
+
+       if (dev) {
+               skb->dev = dev;
+               skb->skb_iif = dev->ifindex;
        }
 
+       skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb);
+
        spin_unlock_bh(&mfc_unres_lock);
+       return 0;
+
+err:
+       spin_unlock_bh(&mfc_unres_lock);
+       if (c)
+               ipmr_cache_free(c);
+       kfree_skb(skb);
        return err;
 }
 
@@ -1346,7 +1354,7 @@ static void mroute_clean_tables(struct mr_table *mrt, int flags,
        }
 
        if (flags & MRT_FLUSH_MFC) {
-               if (atomic_read(&mrt->cache_resolve_queue_len) != 0) {
+               if (atomic_read(&mrt->cache_resolve_queue_len) != 0 || !check_net(net)) {
                        spin_lock_bh(&mfc_unres_lock);
                        list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) {
                                list_del(&c->list);
index 37a3c144276c75d56d5d295a2c852f5ffd262f65..3930d612c3deec21592f3285b447f7912882888c 100644 (file)
@@ -28,6 +28,20 @@ void vif_device_init(struct vif_device *v,
                v->link = dev->ifindex;
 }
 
+static void __mr_free_table(struct work_struct *work)
+{
+       struct mr_table *mrt = container_of(to_rcu_work(work),
+                                           struct mr_table, work);
+
+       rhltable_destroy(&mrt->mfc_hash);
+       kfree(mrt);
+}
+
+void mr_table_free(struct mr_table *mrt)
+{
+       queue_rcu_work(system_unbound_wq, &mrt->work);
+}
+
 struct mr_table *
 mr_table_alloc(struct net *net, u32 id,
               struct mr_table_ops *ops,
@@ -50,6 +64,8 @@ mr_table_alloc(struct net *net, u32 id,
                kfree(mrt);
                return ERR_PTR(err);
        }
+
+       INIT_RCU_WORK(&mrt->work, __mr_free_table);
        INIT_LIST_HEAD(&mrt->mfc_cache_list);
        INIT_LIST_HEAD(&mrt->mfc_unres_queue);