]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
mpls: Add mpls_route_input().
authorKuniyuki Iwashima <kuniyu@google.com>
Wed, 29 Oct 2025 17:33:00 +0000 (17:33 +0000)
committerJakub Kicinski <kuba@kernel.org>
Tue, 4 Nov 2025 01:40:49 +0000 (17:40 -0800)
mpls_route_input_rcu() is called from mpls_forward() and
mpls_getroute().

The former is under RCU, and the latter is under RTNL, so
mpls_route_input_rcu() uses rcu_dereference_rtnl().

Let's use rcu_dereference() in mpls_route_input_rcu() and
add an RTNL variant for mpls_getroute().

Later, we will remove rtnl_dereference() there.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Link: https://patch.msgid.link/20251029173344.2934622-9-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/mpls/af_mpls.c

index a715b12860e957aa40416761aed58fb739582724..530f7e6f7b3ce077f718d6f6e95454d8ae5f1f09 100644 (file)
@@ -75,16 +75,23 @@ static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
                       struct nlmsghdr *nlh, struct net *net, u32 portid,
                       unsigned int nlm_flags);
 
-static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
+static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
 {
-       struct mpls_route *rt = NULL;
+       struct mpls_route __rcu **platform_label;
 
-       if (index < net->mpls.platform_labels) {
-               struct mpls_route __rcu **platform_label =
-                       rcu_dereference_rtnl(net->mpls.platform_label);
-               rt = rcu_dereference_rtnl(platform_label[index]);
-       }
-       return rt;
+       platform_label = rtnl_dereference(net->mpls.platform_label);
+       return rtnl_dereference(platform_label[index]);
+}
+
+static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
+{
+       struct mpls_route __rcu **platform_label;
+
+       if (index >= net->mpls.platform_labels)
+               return NULL;
+
+       platform_label = rcu_dereference(net->mpls.platform_label);
+       return rcu_dereference(platform_label[index]);
 }
 
 bool mpls_output_possible(const struct net_device *dev)
@@ -2373,12 +2380,12 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
        u32 portid = NETLINK_CB(in_skb).portid;
        u32 in_label = LABEL_NOT_SPECIFIED;
        struct nlattr *tb[RTA_MAX + 1];
+       struct mpls_route *rt = NULL;
        u32 labels[MAX_NEW_LABELS];
        struct mpls_shim_hdr *hdr;
        unsigned int hdr_size = 0;
        const struct mpls_nh *nh;
        struct net_device *dev;
-       struct mpls_route *rt;
        struct rtmsg *rtm, *r;
        struct nlmsghdr *nlh;
        struct sk_buff *skb;
@@ -2406,7 +2413,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
                }
        }
 
-       rt = mpls_route_input_rcu(net, in_label);
+       if (in_label < net->mpls.platform_labels)
+               rt = mpls_route_input(net, in_label);
        if (!rt) {
                err = -ENETUNREACH;
                goto errout;