git.ipfire.org Git - thirdparty/kernel/stable.git/commitdiff
ipv6: mcast: Don't hold RTNL for MCAST_ socket options.
author: Kuniyuki Iwashima <kuniyu@google.com>
Wed, 2 Jul 2025 23:01:26 +0000 (16:01 -0700)
committer: Jakub Kicinski <kuba@kernel.org>
Wed, 9 Jul 2025 01:32:38 +0000 (18:32 -0700)
In ip6_mc_source() and ip6_mc_msfilter(), per-socket MLD data is
protected by lock_sock(), and inet6_dev->mc_lock is also held for
some per-interface functions.

Only ip6_mc_find_dev_rtnl() depends on RTNL.  To remove that
dependency, we need to check inet6_dev->dead under mc_lock to close
the race with addrconf_ifdown(), as mentioned in earlier patches of
this series.

Let's do that and drop RTNL for the rest of MCAST_ socket options.

Note that ip6_mc_msfilter() has unnecessary lock dances (repeated
mutex_lock()/mutex_unlock() pairs on idev->mc_lock); they are merged
into a single critical section to avoid a last-minute error and to
simplify the error handling.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20250702230210.3115355-10-kuni1840@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/ipv6/ipv6_sockglue.c
net/ipv6/mcast.c

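diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c
index c8892d54821fb49a2f4308c996f73fe4a293bdae..0c870713b08cefa0eddc03415a861363cbaeb398 100644
--- a/net/ipv6/ipv6_sockglue.c
+++ b/net/ipv6/ipv6_sockglue.c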
@@ -123,11 +123,6 @@ static bool setsockopt_needs_rtnl(int optname)
        case IPV6_ADDRFORM:
        case IPV6_JOIN_ANYCAST:
        case IPV6_LEAVE_ANYCAST:
-       case MCAST_JOIN_SOURCE_GROUP:
-       case MCAST_LEAVE_SOURCE_GROUP:
-       case MCAST_BLOCK_SOURCE:
-       case MCAST_UNBLOCK_SOURCE:
-       case MCAST_MSFILTER:
                return true;
        }
        return false;
diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index 5c5f69f23d4a2a7910e77006aec4ecf7ccfd50ba..edae7770bf8c916bdd8992ae208d5f019a212199 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -302,31 +302,36 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
 }
 EXPORT_SYMBOL(ipv6_sock_mc_drop);
 
-static struct inet6_dev *ip6_mc_find_dev_rtnl(struct net *net,
-                                             const struct in6_addr *group,
-                                             int ifindex)
+static struct inet6_dev *ip6_mc_find_dev(struct net *net,
+                                        const struct in6_addr *group,
+                                        int ifindex)
 {
        struct net_device *dev = NULL;
-       struct inet6_dev *idev = NULL;
+       struct inet6_dev *idev;
 
        if (ifindex == 0) {
-               struct rt6_info *rt = rt6_lookup(net, group, NULL, 0, NULL, 0);
+               struct rt6_info *rt;
 
+               rcu_read_lock();
+               rt = rt6_lookup(net, group, NULL, 0, NULL, 0);
                if (rt) {
-                       dev = rt->dst.dev;
+                       dev = dst_dev(&rt->dst);
+                       dev_hold(dev);
                        ip6_rt_put(rt);
                }
+               rcu_read_unlock();
        } else {
-               dev = __dev_get_by_index(net, ifindex);
+               dev = dev_get_by_index(net, ifindex);
        }
-
        if (!dev)
                return NULL;
-       idev = __in6_dev_get(dev);
+
+       idev = in6_dev_get(dev);
+       dev_put(dev);
+
        if (!idev)
                return NULL;
-       if (idev->dead)
-               return NULL;
+
        return idev;
 }
 
@@ -354,16 +359,16 @@ void ipv6_sock_mc_close(struct sock *sk)
 }
 
 int ip6_mc_source(int add, int omode, struct sock *sk,
-       struct group_source_req *pgsr)
+                 struct group_source_req *pgsr)
 {
+       struct ipv6_pinfo *inet6 = inet6_sk(sk);
        struct in6_addr *source, *group;
+       struct net *net = sock_net(sk);
        struct ipv6_mc_socklist *pmc;
-       struct inet6_dev *idev;
-       struct ipv6_pinfo *inet6 = inet6_sk(sk);
        struct ip6_sf_socklist *psl;
-       struct net *net = sock_net(sk);
-       int i, j, rv;
+       struct inet6_dev *idev;
        int leavegroup = 0;
+       int i, j, rv;
        int err;
 
        source = &((struct sockaddr_in6 *)&pgsr->gsr_source)->sin6_addr;
@@ -372,13 +377,19 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
        if (!ipv6_addr_is_multicast(group))
                return -EINVAL;
 
-       idev = ip6_mc_find_dev_rtnl(net, group, pgsr->gsr_interface);
+       idev = ip6_mc_find_dev(net, group, pgsr->gsr_interface);
        if (!idev)
                return -ENODEV;
 
+       mutex_lock(&idev->mc_lock);
+
+       if (idev->dead) {
+               err = -ENODEV;
+               goto done;
+       }
+
        err = -EADDRNOTAVAIL;
 
-       mutex_lock(&idev->mc_lock);
        for_each_pmc_socklock(inet6, sk, pmc) {
                if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
                        continue;
@@ -475,6 +486,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
        ip6_mc_add_src(idev, group, omode, 1, source, 1);
 done:
        mutex_unlock(&idev->mc_lock);
+       in6_dev_put(idev);
        if (leavegroup)
                err = ipv6_sock_mc_drop(sk, pgsr->gsr_interface, group);
        return err;
@@ -483,12 +495,12 @@ done:
 int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
                    struct sockaddr_storage *list)
 {
-       const struct in6_addr *group;
-       struct ipv6_mc_socklist *pmc;
-       struct inet6_dev *idev;
        struct ipv6_pinfo *inet6 = inet6_sk(sk);
        struct ip6_sf_socklist *newpsl, *psl;
        struct net *net = sock_net(sk);
+       const struct in6_addr *group;
+       struct ipv6_mc_socklist *pmc;
+       struct inet6_dev *idev;
        int leavegroup = 0;
        int i, err;
 
@@ -500,10 +512,17 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
            gsf->gf_fmode != MCAST_EXCLUDE)
                return -EINVAL;
 
-       idev = ip6_mc_find_dev_rtnl(net, group, gsf->gf_interface);
+       idev = ip6_mc_find_dev(net, group, gsf->gf_interface);
        if (!idev)
                return -ENODEV;
 
+       mutex_lock(&idev->mc_lock);
+
+       if (idev->dead) {
+               err = -ENODEV;
+               goto done;
+       }
+
        err = 0;
 
        if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) {
@@ -536,24 +555,19 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
                        psin6 = (struct sockaddr_in6 *)list;
                        newpsl->sl_addr[i] = psin6->sin6_addr;
                }
-               mutex_lock(&idev->mc_lock);
+
                err = ip6_mc_add_src(idev, group, gsf->gf_fmode,
                                     newpsl->sl_count, newpsl->sl_addr, 0);
                if (err) {
-                       mutex_unlock(&idev->mc_lock);
                        sock_kfree_s(sk, newpsl, struct_size(newpsl, sl_addr,
                                                             newpsl->sl_max));
                        goto done;
                }
-               mutex_unlock(&idev->mc_lock);
        } else {
                newpsl = NULL;
-               mutex_lock(&idev->mc_lock);
                ip6_mc_add_src(idev, group, gsf->gf_fmode, 0, NULL, 0);
-               mutex_unlock(&idev->mc_lock);
        }
 
-       mutex_lock(&idev->mc_lock);
        psl = sock_dereference(pmc->sflist, sk);
        if (psl) {
                ip6_mc_del_src(idev, group, pmc->sfmode,
@@ -563,12 +577,14 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
        } else {
                ip6_mc_del_src(idev, group, pmc->sfmode, 0, NULL, 0);
        }
+
        rcu_assign_pointer(pmc->sflist, newpsl);
-       mutex_unlock(&idev->mc_lock);
        kfree_rcu(psl, rcu);
        pmc->sfmode = gsf->gf_fmode;
        err = 0;
 done:
+       mutex_unlock(&idev->mc_lock);
+       in6_dev_put(idev);
        if (leavegroup)
                err = ipv6_sock_mc_drop(sk, gsf->gf_interface, group);
        return err;