]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
ip6mr: Free mr_table after RCU grace period.
authorKuniyuki Iwashima <kuniyu@google.com>
Thu, 4 Jun 2026 22:46:26 +0000 (22:46 +0000)
committerJakub Kicinski <kuba@kernel.org>
Tue, 9 Jun 2026 00:06:23 +0000 (17:06 -0700)
Since default_device_exit_batch() is called after ->exit_rtnl(),
idev->mc_ifc_work could finally call mroute6_is_socket() under RCU
while ->exit_rtnl() is running. [0]

With CONFIG_IPV6_MROUTE_MULTIPLE_TABLES=n, ip6mr_fib_lookup() does
not check if net->ipv6.mrt6 is NULL.  If ip6mr_net_exit_batch()
set net->ipv6.mrt6 to NULL and freed it, the mrt->mroute_sk access
could result in null-ptr-deref or use-after-free.

Let's prepare for that situation by applying RCU rule to ip6mr
table similarly.

!check_net(net) is added in ip6mr_cache_unresolved() and
mroute_clean_tables() to synchronise the two by mfc_unres_lock
so that ip6mr_cache_unresolved() will not queue skb after
mroute_clean_tables() purged &mrt->mfc_unres_queue.

rcu_read_lock() in reg_vif_xmit() is moved up to cover
ip6mr_fib_lookup() as with ipmr.

Link: https://lore.kernel.org/netdev/20260407184202.34cfe2d6@kernel.org/
Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Link: https://patch.msgid.link/20260604224712.3209821-9-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/ipv6/ip6mr.c

index 8c8ad1753c75813dd55830f50146d40bd127fcf0..ddbe06397d9c5c7799c8573fc00d78daaca3a99a 100644 (file)
@@ -136,16 +136,6 @@ static struct mr_table *__ip6mr_get_table(struct net *net, u32 id)
        return NULL;
 }
 
-static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
-{
-       struct mr_table *mrt;
-
-       rcu_read_lock();
-       mrt = __ip6mr_get_table(net, id);
-       rcu_read_unlock();
-       return mrt;
-}
-
 static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6,
                            struct mr_table **mrt)
 {
@@ -274,7 +264,7 @@ static void __net_exit ip6mr_rules_exit(struct net *net)
 
        ASSERT_RTNL();
        list_for_each_entry_safe(mrt, next, &net->ipv6.mr6_tables, list) {
-               list_del(&mrt->list);
+               list_del_rcu(&mrt->list);
                ip6mr_free_table(mrt);
        }
        fib_rules_unregister(net->ipv6.mr6_rules_ops);
@@ -298,28 +288,30 @@ bool ip6mr_rule_default(const struct fib_rule *rule)
 }
 EXPORT_SYMBOL(ip6mr_rule_default);
 #else
-#define ip6mr_for_each_table(mrt, net) \
-       for (mrt = net->ipv6.mrt6; mrt; mrt = NULL)
-
 static struct mr_table *ip6mr_mr_table_iter(struct net *net,
                                            struct mr_table *mrt)
 {
        if (!mrt)
-               return net->ipv6.mrt6;
+               return rcu_dereference(net->ipv6.mrt6);
        return NULL;
 }
 
-static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
+static struct mr_table *__ip6mr_get_table(struct net *net, u32 id)
 {
-       return net->ipv6.mrt6;
+       return rcu_dereference_check(net->ipv6.mrt6,
+                                    lockdep_rtnl_is_held() ||
+                                    !rcu_access_pointer(net->ipv6.mrt6));
 }
 
-#define __ip6mr_get_table ip6mr_get_table
+#define ip6mr_for_each_table(mrt, net)                         \
+       for (mrt = __ip6mr_get_table(net, 0); mrt; mrt = NULL)
 
 static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6,
                            struct mr_table **mrt)
 {
-       *mrt = net->ipv6.mrt6;
+       *mrt = rcu_dereference(net->ipv6.mrt6);
+       if (!*mrt)
+               return -EAGAIN;
        return 0;
 }
 
@@ -330,15 +322,19 @@ static int __net_init ip6mr_rules_init(struct net *net)
        mrt = ip6mr_new_table(net, RT6_TABLE_DFLT);
        if (IS_ERR(mrt))
                return PTR_ERR(mrt);
-       net->ipv6.mrt6 = mrt;
+
+       rcu_assign_pointer(net->ipv6.mrt6, mrt);
        return 0;
 }
 
 static void __net_exit ip6mr_rules_exit(struct net *net)
 {
+       struct mr_table *mrt = rcu_dereference_protected(net->ipv6.mrt6, 1);
+
        ASSERT_RTNL();
-       ip6mr_free_table(net->ipv6.mrt6);
-       net->ipv6.mrt6 = NULL;
+
+       RCU_INIT_POINTER(net->ipv6.mrt6, NULL);
+       ip6mr_free_table(mrt);
 }
 
 static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb,
@@ -353,6 +349,17 @@ static unsigned int ip6mr_rules_seq_read(const struct net *net)
 }
 #endif
 
+static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
+{
+       struct mr_table *mrt;
+
+       rcu_read_lock();
+       mrt = __ip6mr_get_table(net, id);
+       rcu_read_unlock();
+
+       return mrt;
+}
+
 static int ip6mr_hash_cmp(struct rhashtable_compare_arg *arg,
                          const void *ptr)
 {
@@ -411,8 +418,8 @@ static void ip6mr_free_table(struct mr_table *mrt)
        timer_shutdown_sync(&mrt->ipmr_expire_timer);
        mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC |
                                 MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC);
-       rhltable_destroy(&mrt->mfc_hash);
-       kfree(mrt);
+
+       mr_table_free(mrt);
 }
 
 #ifdef CONFIG_PROC_FS
@@ -623,18 +630,22 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb,
        if (!pskb_inet_may_pull(skb))
                goto tx_err;
 
+       rcu_read_lock();
+
        if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
-               goto tx_err;
+               goto tx_lookup_err;
 
        DEV_STATS_ADD(dev, tx_bytes, skb->len);
        DEV_STATS_INC(dev, tx_packets);
-       rcu_read_lock();
+
        ip6mr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num),
                           MRT6MSG_WHOLEPKT);
        rcu_read_unlock();
        kfree_skb(skb);
        return NETDEV_TX_OK;
 
+tx_lookup_err:
+       rcu_read_unlock();
 tx_err:
        DEV_STATS_INC(dev, tx_errors);
        kfree_skb(skb);
@@ -1157,11 +1168,18 @@ static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt,
 static int ip6mr_cache_unresolved(struct mr_table *mrt, mifi_t mifi,
                                  struct sk_buff *skb, struct net_device *dev)
 {
-       struct mfc6_cache *c;
+       struct net *net = read_pnet(&mrt->net);
+       struct mfc6_cache *c = NULL;
        bool found = false;
        int err;
 
        spin_lock_bh(&mfc_unres_lock);
+
+       if (!check_net(net)) {
+               err = -EINVAL;
+               goto err;
+       }
+
        list_for_each_entry(c, &mrt->mfc_unres_queue, _c.list) {
                if (ipv6_addr_equal(&c->mf6c_mcastgrp, &ipv6_hdr(skb)->daddr) &&
                    ipv6_addr_equal(&c->mf6c_origin, &ipv6_hdr(skb)->saddr)) {
@@ -1177,10 +1195,8 @@ static int ip6mr_cache_unresolved(struct mr_table *mrt, mifi_t mifi,
 
                c = ip6mr_cache_alloc_unres();
                if (!c) {
-                       spin_unlock_bh(&mfc_unres_lock);
-
-                       kfree_skb(skb);
-                       return -ENOBUFS;
+                       err = -ENOBUFS;
+                       goto err;
                }
 
                /* Fill in the new cache entry */
@@ -1192,16 +1208,8 @@ static int ip6mr_cache_unresolved(struct mr_table *mrt, mifi_t mifi,
                 *      Reflect first query at pim6sd
                 */
                err = ip6mr_cache_report(mrt, skb, mifi, MRT6MSG_NOCACHE);
-               if (err < 0) {
-                       /* If the report failed throw the cache entry
-                          out - Brad Parker
-                        */
-                       spin_unlock_bh(&mfc_unres_lock);
-
-                       ip6mr_cache_free(c);
-                       kfree_skb(skb);
-                       return err;
-               }
+               if (err < 0)
+                       goto err;
 
                atomic_inc(&mrt->cache_resolve_queue_len);
                list_add(&c->_c.list, &mrt->mfc_unres_queue);
@@ -1212,18 +1220,26 @@ static int ip6mr_cache_unresolved(struct mr_table *mrt, mifi_t mifi,
 
        /* See if we can append the packet */
        if (c->_c.mfc_un.unres.unresolved.qlen > 3) {
-               kfree_skb(skb);
+               c = NULL;
                err = -ENOBUFS;
-       } else {
-               if (dev) {
-                       skb->dev = dev;
-                       skb->skb_iif = dev->ifindex;
-               }
-               skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb);
-               err = 0;
+               goto err;
+       }
+
+       if (dev) {
+               skb->dev = dev;
+               skb->skb_iif = dev->ifindex;
        }
 
+       skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb);
+
+       spin_unlock_bh(&mfc_unres_lock);
+       return 0;
+
+err:
        spin_unlock_bh(&mfc_unres_lock);
+       if (c)
+               ip6mr_cache_free(c);
+       kfree_skb(skb);
        return err;
 }
 
@@ -1534,6 +1550,7 @@ static int ip6mr_mfc_add(struct net *net, struct mr_table *mrt,
 
 static void mroute_clean_tables(struct mr_table *mrt, int flags)
 {
+       struct net *net = read_pnet(&mrt->net);
        struct mr_mfc *c, *tmp;
        LIST_HEAD(list);
        int i;
@@ -1558,8 +1575,7 @@ static void mroute_clean_tables(struct mr_table *mrt, int flags)
                                continue;
                        rhltable_remove(&mrt->mfc_hash, &c->mnode, ip6mr_rht_params);
                        list_del_rcu(&c->list);
-                       call_ip6mr_mfc_entry_notifiers(read_pnet(&mrt->net),
-                                                      FIB_EVENT_ENTRY_DEL,
+                       call_ip6mr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_DEL,
                                                       (struct mfc6_cache *)c, mrt->id);
                        mr6_netlink_event(mrt, (struct mfc6_cache *)c, RTM_DELROUTE);
                        mr_cache_put(c);
@@ -1567,7 +1583,8 @@ static void mroute_clean_tables(struct mr_table *mrt, int flags)
        }
 
        if (flags & MRT6_FLUSH_MFC) {
-               if (atomic_read(&mrt->cache_resolve_queue_len) != 0) {
+               if (atomic_read(&mrt->cache_resolve_queue_len) != 0 ||
+                   !check_net(net)) {
                        spin_lock_bh(&mfc_unres_lock);
                        list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) {
                                list_del(&c->list);