]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
mpls: add seqcount to protect the platform_label{,s} pair
authorSabrina Dubroca <sd@queasysnail.net>
Mon, 23 Mar 2026 23:25:57 +0000 (00:25 +0100)
committerJakub Kicinski <kuba@kernel.org>
Fri, 27 Mar 2026 01:32:14 +0000 (18:32 -0700)
The RCU-protected codepaths (mpls_forward, mpls_dump_routes) can have
an inconsistent view of platform_labels vs platform_label in case of a
concurrent resize (resize_platform_label_table, under
platform_mutex). This can lead to OOB accesses.

This patch adds a seqcount, so that we get a consistent snapshot.

Note that mpls_label_ok is also susceptible to this, so the check
against RTA_DST in rtm_to_route_config, done outside platform_mutex,
is not sufficient. This value gets passed to mpls_label_ok once more
in both mpls_route_add and mpls_route_del, so there is no issue, but
that additional check must not be removed.

Reported-by: Yuan Tan <tanyuan98@outlook.com>
Reported-by: Yifan Wu <yifanwucs@gmail.com>
Reported-by: Juefei Pu <tomapufckgml@gmail.com>
Reported-by: Xin Liu <bird@lzu.edu.cn>
Fixes: 7720c01f3f590 ("mpls: Add a sysctl to control the size of the mpls label table")
Fixes: dde1b38e873c ("mpls: Convert mpls_dump_routes() to RCU.")
Signed-off-by: Sabrina Dubroca <sd@queasysnail.net>
Link: https://patch.msgid.link/cd8fca15e3eb7e212b094064cd83652e20fd9d31.1774284088.git.sd@queasysnail.net
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/net/netns/mpls.h
net/mpls/af_mpls.c

index 6682e51513efa5d0b48d41b4e9eeb924eada5794..2073cbac2afb51fd319784af2866319d4aa5d2db 100644 (file)
@@ -17,6 +17,7 @@ struct netns_mpls {
        size_t platform_labels;
        struct mpls_route __rcu * __rcu *platform_label;
        struct mutex platform_mutex;
+       seqcount_mutex_t platform_label_seq;
 
        struct ctl_table_header *ctl;
 };
index d5417688f69e6316fa368f05843d814cb648dbca..18d3da8ab384874de5aea8c2b3243e277290cfba 100644 (file)
@@ -83,14 +83,30 @@ static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
        return mpls_dereference(net, platform_label[index]);
 }
 
+static struct mpls_route __rcu **mpls_platform_label_rcu(struct net *net, size_t *platform_labels)
+{
+       struct mpls_route __rcu **platform_label;
+       unsigned int sequence;
+
+       do {
+               sequence = read_seqcount_begin(&net->mpls.platform_label_seq);
+               platform_label = rcu_dereference(net->mpls.platform_label);
+               *platform_labels = net->mpls.platform_labels;
+       } while (read_seqcount_retry(&net->mpls.platform_label_seq, sequence));
+
+       return platform_label;
+}
+
 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
 {
        struct mpls_route __rcu **platform_label;
+       size_t platform_labels;
+
+       platform_label = mpls_platform_label_rcu(net, &platform_labels);
 
-       if (index >= net->mpls.platform_labels)
+       if (index >= platform_labels)
                return NULL;
 
-       platform_label = rcu_dereference(net->mpls.platform_label);
        return rcu_dereference(platform_label[index]);
 }
 
@@ -2240,8 +2256,7 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
        if (index < MPLS_LABEL_FIRST_UNRESERVED)
                index = MPLS_LABEL_FIRST_UNRESERVED;
 
-       platform_label = rcu_dereference(net->mpls.platform_label);
-       platform_labels = net->mpls.platform_labels;
+       platform_label = mpls_platform_label_rcu(net, &platform_labels);
 
        if (filter.filter_set)
                flags |= NLM_F_DUMP_FILTERED;
@@ -2645,8 +2660,12 @@ static int resize_platform_label_table(struct net *net, size_t limit)
        }
 
        /* Update the global pointers */
+       local_bh_disable();
+       write_seqcount_begin(&net->mpls.platform_label_seq);
        net->mpls.platform_labels = limit;
        rcu_assign_pointer(net->mpls.platform_label, labels);
+       write_seqcount_end(&net->mpls.platform_label_seq);
+       local_bh_enable();
 
        mutex_unlock(&net->mpls.platform_mutex);
 
@@ -2728,6 +2747,8 @@ static __net_init int mpls_net_init(struct net *net)
        int i;
 
        mutex_init(&net->mpls.platform_mutex);
+       seqcount_mutex_init(&net->mpls.platform_label_seq, &net->mpls.platform_mutex);
+
        net->mpls.platform_labels = 0;
        net->mpls.platform_label = NULL;
        net->mpls.ip_ttl_propagate = 1;