]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
mpls: Use in6_dev_rcu() and dev_net_rcu() in mpls_forward() and mpls_xmit().
authorKuniyuki Iwashima <kuniyu@google.com>
Wed, 29 Oct 2025 17:32:57 +0000 (17:32 +0000)
committerJakub Kicinski <kuba@kernel.org>
Tue, 4 Nov 2025 01:40:47 +0000 (17:40 -0800)
mpls_forward() and mpls_xmit() are called under RCU.

Let's use in6_dev_rcu() and dev_net_rcu() there to annotate
as such.

Now we pass net to mpls_stats_inc_outucastpkts() not to read
dev_net_rcu() twice.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Link: https://patch.msgid.link/20251029173344.2934622-6-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/mpls/af_mpls.c
net/mpls/internal.h
net/mpls/mpls_iptunnel.c

index c5bbf712f8be0a16c00b6d6e52a4ca24b9afc7c7..efc6c7da5766aa17e091f40c37179cbfbef87cee 100644 (file)
@@ -129,7 +129,8 @@ bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
 }
 EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
 
-void mpls_stats_inc_outucastpkts(struct net_device *dev,
+void mpls_stats_inc_outucastpkts(struct net *net,
+                                struct net_device *dev,
                                 const struct sk_buff *skb)
 {
        struct mpls_dev *mdev;
@@ -141,13 +142,13 @@ void mpls_stats_inc_outucastpkts(struct net_device *dev,
                                           tx_packets,
                                           tx_bytes);
        } else if (skb->protocol == htons(ETH_P_IP)) {
-               IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len);
+               IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len);
 #if IS_ENABLED(CONFIG_IPV6)
        } else if (skb->protocol == htons(ETH_P_IPV6)) {
-               struct inet6_dev *in6dev = __in6_dev_get(dev);
+               struct inet6_dev *in6dev = in6_dev_rcu(dev);
 
                if (in6dev)
-                       IP6_UPD_PO_STATS(dev_net(dev), in6dev,
+                       IP6_UPD_PO_STATS(net, in6dev,
                                         IPSTATS_MIB_OUT, skb->len);
 #endif
        }
@@ -342,7 +343,7 @@ static bool mpls_egress(struct net *net, struct mpls_route *rt,
 static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
                        struct packet_type *pt, struct net_device *orig_dev)
 {
-       struct net *net = dev_net(dev);
+       struct net *net = dev_net_rcu(dev);
        struct mpls_shim_hdr *hdr;
        const struct mpls_nh *nh;
        struct mpls_route *rt;
@@ -434,7 +435,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
        dec.ttl -= 1;
        if (unlikely(!new_header_size && dec.bos)) {
                /* Penultimate hop popping */
-               if (!mpls_egress(dev_net(out_dev), rt, skb, dec))
+               if (!mpls_egress(net, rt, skb, dec))
                        goto err;
        } else {
                bool bos;
@@ -451,7 +452,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
                }
        }
 
-       mpls_stats_inc_outucastpkts(out_dev, skb);
+       mpls_stats_inc_outucastpkts(net, out_dev, skb);
 
        /* If via wasn't specified then send out using device address */
        if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
index 3a5feca27d6af5f6e4ed1cf1e2f93f108fa52faa..e491427ea08aea5571df7ec141737ae36c563bbe 100644 (file)
@@ -197,7 +197,8 @@ int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels,
 bool mpls_output_possible(const struct net_device *dev);
 unsigned int mpls_dev_mtu(const struct net_device *dev);
 bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu);
-void mpls_stats_inc_outucastpkts(struct net_device *dev,
+void mpls_stats_inc_outucastpkts(struct net *net,
+                                struct net_device *dev,
                                 const struct sk_buff *skb);
 
 #endif /* MPLS_INTERNAL_H */
index 6e73da94af7fba2da61567f97d6cc4bae55b3a5d..cfbab7b2fec739a572de4a689dff0b442587fb3c 100644 (file)
@@ -53,7 +53,7 @@ static int mpls_xmit(struct sk_buff *skb)
 
        /* Find the output device */
        out_dev = dst->dev;
-       net = dev_net(out_dev);
+       net = dev_net_rcu(out_dev);
 
        if (!mpls_output_possible(out_dev) ||
            !dst->lwtstate || skb_warn_if_lro(skb))
@@ -128,7 +128,7 @@ static int mpls_xmit(struct sk_buff *skb)
                bos = false;
        }
 
-       mpls_stats_inc_outucastpkts(out_dev, skb);
+       mpls_stats_inc_outucastpkts(net, out_dev, skb);
 
        if (rt) {
                if (rt->rt_gw_family == AF_INET6)