]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
mpls: Protect net->mpls.platform_label with a per-netns mutex.
authorKuniyuki Iwashima <kuniyu@google.com>
Wed, 29 Oct 2025 17:33:04 +0000 (17:33 +0000)
committerJakub Kicinski <kuba@kernel.org>
Tue, 4 Nov 2025 01:40:53 +0000 (17:40 -0800)
MPLS (re)uses RTNL to protect net->mpls.platform_label,
but the lock does not need to be RTNL at all.

Let's protect net->mpls.platform_label with a dedicated
per-netns mutex.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Link: https://patch.msgid.link/20251029173344.2934622-13-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/net/netns/mpls.h
net/mpls/af_mpls.c
net/mpls/internal.h

index 19ad2574b267ca4bb08980cae6ed05bfa0fc20ae..6682e51513efa5d0b48d41b4e9eeb924eada5794 100644 (file)
@@ -16,6 +16,7 @@ struct netns_mpls {
        int default_ttl;
        size_t platform_labels;
        struct mpls_route __rcu * __rcu *platform_label;
+       struct mutex platform_mutex;
 
        struct ctl_table_header *ctl;
 };
index 49fd15232dbec75004cbf22809d9b5fa0a31ae85..d0d047dd2245f6ab7dccd8907b6b1c1a431c03ad 100644 (file)
@@ -79,8 +79,8 @@ static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
 {
        struct mpls_route __rcu **platform_label;
 
-       platform_label = rtnl_dereference(net->mpls.platform_label);
-       return rtnl_dereference(platform_label[index]);
+       platform_label = mpls_dereference(net, net->mpls.platform_label);
+       return mpls_dereference(net, platform_label[index]);
 }
 
 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
@@ -578,10 +578,8 @@ static void mpls_route_update(struct net *net, unsigned index,
        struct mpls_route __rcu **platform_label;
        struct mpls_route *rt;
 
-       ASSERT_RTNL();
-
-       platform_label = rtnl_dereference(net->mpls.platform_label);
-       rt = rtnl_dereference(platform_label[index]);
+       platform_label = mpls_dereference(net, net->mpls.platform_label);
+       rt = mpls_dereference(net, platform_label[index]);
        rcu_assign_pointer(platform_label[index], new);
 
        mpls_notify_route(net, index, rt, new, info);
@@ -1472,8 +1470,6 @@ static struct mpls_dev *mpls_add_dev(struct net_device *dev)
        int err = -ENOMEM;
        int i;
 
-       ASSERT_RTNL();
-
        mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
        if (!mdev)
                return ERR_PTR(err);
@@ -1633,6 +1629,8 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
        unsigned int flags;
        int err;
 
+       mutex_lock(&net->mpls.platform_mutex);
+
        if (event == NETDEV_REGISTER) {
                mdev = mpls_add_dev(dev);
                if (IS_ERR(mdev)) {
@@ -1695,9 +1693,11 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
        }
 
 out:
+       mutex_unlock(&net->mpls.platform_mutex);
        return NOTIFY_OK;
 
 err:
+       mutex_unlock(&net->mpls.platform_mutex);
        return notifier_from_errno(err);
 }
 
@@ -1973,6 +1973,7 @@ errout:
 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
                             struct netlink_ext_ack *extack)
 {
+       struct net *net = sock_net(skb->sk);
        struct mpls_route_config *cfg;
        int err;
 
@@ -1984,7 +1985,9 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (err < 0)
                goto out;
 
+       mutex_lock(&net->mpls.platform_mutex);
        err = mpls_route_del(cfg, extack);
+       mutex_unlock(&net->mpls.platform_mutex);
 out:
        kfree(cfg);
 
@@ -1995,6 +1998,7 @@ out:
 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
                             struct netlink_ext_ack *extack)
 {
+       struct net *net = sock_net(skb->sk);
        struct mpls_route_config *cfg;
        int err;
 
@@ -2006,7 +2010,9 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (err < 0)
                goto out;
 
+       mutex_lock(&net->mpls.platform_mutex);
        err = mpls_route_add(cfg, extack);
+       mutex_unlock(&net->mpls.platform_mutex);
 out:
        kfree(cfg);
 
@@ -2407,6 +2413,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
        u8 n_labels;
        int err;
 
+       mutex_lock(&net->mpls.platform_mutex);
+
        err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
        if (err < 0)
                goto errout;
@@ -2450,7 +2458,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
                        goto errout_free;
                }
 
-               return rtnl_unicast(skb, net, portid);
+               err = rtnl_unicast(skb, net, portid);
+               goto errout;
        }
 
        if (tb[RTA_NEWDST]) {
@@ -2542,12 +2551,14 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
 
        err = rtnl_unicast(skb, net, portid);
 errout:
+       mutex_unlock(&net->mpls.platform_mutex);
        return err;
 
 nla_put_failure:
        nlmsg_cancel(skb, nlh);
        err = -EMSGSIZE;
 errout_free:
+       mutex_unlock(&net->mpls.platform_mutex);
        kfree_skb(skb);
        return err;
 }
@@ -2603,9 +2614,10 @@ static int resize_platform_label_table(struct net *net, size_t limit)
                       lo->addr_len);
        }
 
-       rtnl_lock();
+       mutex_lock(&net->mpls.platform_mutex);
+
        /* Remember the original table */
-       old = rtnl_dereference(net->mpls.platform_label);
+       old = mpls_dereference(net, net->mpls.platform_label);
        old_limit = net->mpls.platform_labels;
 
        /* Free any labels beyond the new table */
@@ -2636,7 +2648,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
        net->mpls.platform_labels = limit;
        rcu_assign_pointer(net->mpls.platform_label, labels);
 
-       rtnl_unlock();
+       mutex_unlock(&net->mpls.platform_mutex);
 
        mpls_rt_free(rt2);
        mpls_rt_free(rt0);
@@ -2709,12 +2721,13 @@ static const struct ctl_table mpls_table[] = {
        },
 };
 
-static int mpls_net_init(struct net *net)
+static __net_init int mpls_net_init(struct net *net)
 {
        size_t table_size = ARRAY_SIZE(mpls_table);
        struct ctl_table *table;
        int i;
 
+       mutex_init(&net->mpls.platform_mutex);
        net->mpls.platform_labels = 0;
        net->mpls.platform_label = NULL;
        net->mpls.ip_ttl_propagate = 1;
@@ -2740,7 +2753,7 @@ static int mpls_net_init(struct net *net)
        return 0;
 }
 
-static void mpls_net_exit(struct net *net)
+static __net_exit void mpls_net_exit(struct net *net)
 {
        struct mpls_route __rcu **platform_label;
        size_t platform_labels;
@@ -2760,16 +2773,20 @@ static void mpls_net_exit(struct net *net)
         * As such no additional rcu synchronization is necessary when
         * freeing the platform_label table.
         */
-       rtnl_lock();
-       platform_label = rtnl_dereference(net->mpls.platform_label);
+       mutex_lock(&net->mpls.platform_mutex);
+
+       platform_label = mpls_dereference(net, net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
+
        for (index = 0; index < platform_labels; index++) {
-               struct mpls_route *rt = rtnl_dereference(platform_label[index]);
-               RCU_INIT_POINTER(platform_label[index], NULL);
+               struct mpls_route *rt;
+
+               rt = mpls_dereference(net, platform_label[index]);
                mpls_notify_route(net, index, rt, NULL, NULL);
                mpls_rt_free(rt);
        }
-       rtnl_unlock();
+
+       mutex_unlock(&net->mpls.platform_mutex);
 
        kvfree(platform_label);
 }
index 0df01a5395eea1ba8efe7b80a4db1ff384d160b6..80cb5bbcd9465d4f256761c777885f00ccf4ea53 100644 (file)
@@ -185,6 +185,11 @@ static inline struct mpls_entry_decoded mpls_entry_decode(struct mpls_shim_hdr *
        return result;
 }
 
+#define mpls_dereference(net, p)                                       \
+       rcu_dereference_protected(                                      \
+               (p),                                                    \
+               lockdep_is_held(&(net)->mpls.platform_mutex))
+
 static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
 {
        return rcu_dereference(dev->mpls_ptr);
@@ -193,7 +198,7 @@ static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
 static inline struct mpls_dev *mpls_dev_get(const struct net *net,
                                            const struct net_device *dev)
 {
-       return rcu_dereference_rtnl(dev->mpls_ptr);
+       return mpls_dereference(net, dev->mpls_ptr);
 }
 
 int nla_put_labels(struct sk_buff *skb, int attrtype,  u8 labels,