]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
amt: Store struct sock in struct amt_dev.
authorKuniyuki Iwashima <kuniyu@google.com>
Sat, 2 May 2026 03:13:05 +0000 (03:13 +0000)
committerJakub Kicinski <kuba@kernel.org>
Wed, 6 May 2026 00:47:05 +0000 (17:47 -0700)
amt does not need to access struct socket itself in the fast path;
it only reads struct sock, and struct socket is only used for tunnel
setup and teardown.

Let's store struct sock directly in struct amt.

amt_dev_stop() is called as dev->netdev_ops->ndo_stop().
synchronize_net() in unregister_netdevice_many_notify() ensures
that inflight amt RX fast paths finish before amt_dev is freed.

amt no longer needs synchronize_rcu() in udp_tunnel_sock_release().

Note that amt_dev_stop() looks buggy; cancel_delayed_work_sync()
should be called after udp_tunnel_sock_release().

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Link: https://patch.msgid.link/20260502031401.3557229-13-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/net/amt.c
include/net/amt.h

index c03aa7c207e648893928c95fa3cef7a2dfb0864d..724a8163a5142a6835950abb63d80f29417b2654 100644 (file)
@@ -614,24 +614,24 @@ static void amt_send_discovery(struct amt_dev *amt)
 {
        struct amt_header_discovery *amtd;
        int hlen, tlen, offset;
-       struct socket *sock;
        struct udphdr *udph;
        struct sk_buff *skb;
        struct iphdr *iph;
        struct rtable *rt;
        struct flowi4 fl4;
+       struct sock *sk;
        u32 len;
        int err;
 
        rcu_read_lock();
-       sock = rcu_dereference(amt->sock);
-       if (!sock)
+       sk = rcu_dereference(amt->sk);
+       if (!sk)
                goto out;
 
        if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
                goto out;
 
-       rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
+       rt = ip_route_output_ports(amt->net, &fl4, sk,
                                   amt->discovery_ip, amt->local_ip,
                                   amt->gw_port, amt->relay_port,
                                   IPPROTO_UDP, 0,
@@ -690,7 +690,7 @@ static void amt_send_discovery(struct amt_dev *amt)
        skb->ip_summed = CHECKSUM_NONE;
        ip_select_ident(amt->net, skb, NULL);
        ip_send_check(iph);
-       err = ip_local_out(amt->net, sock->sk, skb);
+       err = ip_local_out(amt->net, sk, skb);
        if (unlikely(net_xmit_eval(err)))
                amt->dev->stats.tx_errors++;
 
@@ -703,24 +703,24 @@ static void amt_send_request(struct amt_dev *amt, bool v6)
 {
        struct amt_header_request *amtrh;
        int hlen, tlen, offset;
-       struct socket *sock;
        struct udphdr *udph;
        struct sk_buff *skb;
        struct iphdr *iph;
        struct rtable *rt;
        struct flowi4 fl4;
+       struct sock *sk;
        u32 len;
        int err;
 
        rcu_read_lock();
-       sock = rcu_dereference(amt->sock);
-       if (!sock)
+       sk = rcu_dereference(amt->sk);
+       if (!sk)
                goto out;
 
        if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
                goto out;
 
-       rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
+       rt = ip_route_output_ports(amt->net, &fl4, sk,
                                   amt->remote_ip, amt->local_ip,
                                   amt->gw_port, amt->relay_port,
                                   IPPROTO_UDP, 0,
@@ -781,7 +781,7 @@ static void amt_send_request(struct amt_dev *amt, bool v6)
        skb->ip_summed = CHECKSUM_NONE;
        ip_select_ident(amt->net, skb, NULL);
        ip_send_check(iph);
-       err = ip_local_out(amt->net, sock->sk, skb);
+       err = ip_local_out(amt->net, sk, skb);
        if (unlikely(net_xmit_eval(err)))
                amt->dev->stats.tx_errors++;
 
@@ -1000,14 +1000,14 @@ static bool amt_send_membership_update(struct amt_dev *amt,
                                       bool v6)
 {
        struct amt_header_membership_update *amtmu;
-       struct socket *sock;
        struct iphdr *iph;
        struct flowi4 fl4;
        struct rtable *rt;
+       struct sock *sk;
        int err;
 
-       sock = rcu_dereference_bh(amt->sock);
-       if (!sock)
+       sk = rcu_dereference_bh(amt->sk);
+       if (!sk)
                return true;
 
        err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
@@ -1039,7 +1039,7 @@ static bool amt_send_membership_update(struct amt_dev *amt,
                skb_set_inner_protocol(skb, htons(ETH_P_IP));
        else
                skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
-       udp_tunnel_xmit_skb(rt, sock->sk, skb,
+       udp_tunnel_xmit_skb(rt, sk, skb,
                            fl4.saddr,
                            fl4.daddr,
                            AMT_TOS,
@@ -1060,14 +1060,14 @@ static void amt_send_multicast_data(struct amt_dev *amt,
                                    bool v6)
 {
        struct amt_header_mcast_data *amtmd;
-       struct socket *sock;
        struct sk_buff *skb;
        struct iphdr *iph;
        struct flowi4 fl4;
        struct rtable *rt;
+       struct sock *sk;
 
-       sock = rcu_dereference_bh(amt->sock);
-       if (!sock)
+       sk = rcu_dereference_bh(amt->sk);
+       if (!sk)
                return;
 
        skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
@@ -1097,7 +1097,7 @@ static void amt_send_multicast_data(struct amt_dev *amt,
                skb_set_inner_protocol(skb, htons(ETH_P_IP));
        else
                skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
-       udp_tunnel_xmit_skb(rt, sock->sk, skb,
+       udp_tunnel_xmit_skb(rt, sk, skb,
                            fl4.saddr,
                            fl4.daddr,
                            AMT_TOS,
@@ -1116,13 +1116,13 @@ static bool amt_send_membership_query(struct amt_dev *amt,
                                      bool v6)
 {
        struct amt_header_membership_query *amtmq;
-       struct socket *sock;
        struct rtable *rt;
        struct flowi4 fl4;
+       struct sock *sk;
        int err;
 
-       sock = rcu_dereference_bh(amt->sock);
-       if (!sock)
+       sk = rcu_dereference_bh(amt->sk);
+       if (!sk)
                return true;
 
        err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
@@ -1156,7 +1156,7 @@ static bool amt_send_membership_query(struct amt_dev *amt,
                skb_set_inner_protocol(skb, htons(ETH_P_IP));
        else
                skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
-       udp_tunnel_xmit_skb(rt, sock->sk, skb,
+       udp_tunnel_xmit_skb(rt, sk, skb,
                            fl4.saddr,
                            fl4.daddr,
                            AMT_TOS,
@@ -2554,24 +2554,24 @@ static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
 {
        struct amt_header_advertisement *amta;
        int hlen, tlen, offset;
-       struct socket *sock;
        struct udphdr *udph;
        struct sk_buff *skb;
        struct iphdr *iph;
        struct rtable *rt;
        struct flowi4 fl4;
+       struct sock *sk;
        u32 len;
        int err;
 
        rcu_read_lock();
-       sock = rcu_dereference(amt->sock);
-       if (!sock)
+       sk = rcu_dereference(amt->sk);
+       if (!sk)
                goto out;
 
        if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
                goto out;
 
-       rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
+       rt = ip_route_output_ports(amt->net, &fl4, sk,
                                   daddr, amt->local_ip,
                                   dport, amt->relay_port,
                                   IPPROTO_UDP, 0,
@@ -2631,7 +2631,7 @@ static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
        skb->ip_summed = CHECKSUM_NONE;
        ip_select_ident(amt->net, skb, NULL);
        ip_send_check(iph);
-       err = ip_local_out(amt->net, sock->sk, skb);
+       err = ip_local_out(amt->net, sk, skb);
        if (unlikely(net_xmit_eval(err)))
                amt->dev->stats.tx_errors++;
 
@@ -2944,7 +2944,7 @@ drop:
        return 0;
 }
 
-static struct socket *amt_create_sock(struct net *net, __be16 port)
+static struct sock *amt_create_sock(struct net *net, __be16 port)
 {
        struct udp_port_cfg udp_conf;
        struct socket *sock;
@@ -2960,17 +2960,17 @@ static struct socket *amt_create_sock(struct net *net, __be16 port)
        if (err < 0)
                return ERR_PTR(err);
 
-       return sock;
+       return sock->sk;
 }
 
 static int amt_socket_create(struct amt_dev *amt)
 {
        struct udp_tunnel_sock_cfg tunnel_cfg;
-       struct socket *sock;
+       struct sock *sk;
 
-       sock = amt_create_sock(amt->net, amt->relay_port);
-       if (IS_ERR(sock))
-               return PTR_ERR(sock);
+       sk = amt_create_sock(amt->net, amt->relay_port);
+       if (IS_ERR(sk))
+               return PTR_ERR(sk);
 
        /* Mark socket as an encapsulation socket */
        memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
@@ -2979,9 +2979,9 @@ static int amt_socket_create(struct amt_dev *amt)
        tunnel_cfg.encap_rcv = amt_rcv;
        tunnel_cfg.encap_err_lookup = amt_err_lookup;
        tunnel_cfg.encap_destroy = NULL;
-       setup_udp_tunnel_sock(amt->net, sock->sk, &tunnel_cfg);
+       setup_udp_tunnel_sock(amt->net, sk, &tunnel_cfg);
 
-       rcu_assign_pointer(amt->sock, sock);
+       rcu_assign_pointer(amt->sk, sk);
        return 0;
 }
 
@@ -3019,8 +3019,8 @@ static int amt_dev_stop(struct net_device *dev)
 {
        struct amt_dev *amt = netdev_priv(dev);
        struct amt_tunnel_list *tunnel, *tmp;
-       struct socket *sock;
        struct sk_buff *skb;
+       struct sock *sk;
        int i;
 
        cancel_delayed_work_sync(&amt->req_wq);
@@ -3028,11 +3028,11 @@ static int amt_dev_stop(struct net_device *dev)
        cancel_delayed_work_sync(&amt->secret_wq);
 
        /* shutdown */
-       sock = rtnl_dereference(amt->sock);
-       RCU_INIT_POINTER(amt->sock, NULL);
+       sk = rtnl_dereference(amt->sk);
+       RCU_INIT_POINTER(amt->sk, NULL);
        synchronize_net();
-       if (sock)
-               udp_tunnel_sock_release(sock->sk);
+       if (sk)
+               udp_tunnel_sock_release(sk);
 
        cancel_work_sync(&amt->event_wq);
        for (i = 0; i < AMT_MAX_EVENTS; i++) {
index c881bc8b673b3c47200471f9da73cbce4a6d6fd4..a0255491f5b05ead2caaabb8b50362ebf8d3225b 100644 (file)
@@ -331,7 +331,7 @@ struct amt_dev {
        enum amt_status         status;
        /* Generated key */
        siphash_key_t           key;
-       struct socket     __rcu *sock;
+       struct sock       __rcu *sk;
        u32                     max_groups;
        u32                     max_sources;
        u32                     hash_buckets;