]> git.ipfire.org Git - thirdparty/kernel/stable.git/commitdiff
net: annotate data-races around sk->sk_forward_alloc
authorEric Dumazet <edumazet@google.com>
Thu, 31 Aug 2023 13:52:09 +0000 (13:52 +0000)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Tue, 19 Sep 2023 10:28:01 +0000 (12:28 +0200)
[ Upstream commit 5e6300e7b3a4ab5b72a82079753868e91fbf9efc ]

Every time sk->sk_forward_alloc is read locklessly,
add a READ_ONCE().

Add sk_forward_alloc_add() helper to centralize updates,
to reduce number of WRITE_ONCE().

Fixes: 1da177e4c3f4 ("Linux-2.6.12-rc2")
Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Sasha Levin <sashal@kernel.org>
include/net/sock.h
net/core/sock.c
net/ipv4/tcp_output.c
net/ipv4/udp.c
net/mptcp/protocol.c

index d1f936ed97556aba5d38983ce505bb59698bf20a..fe695e8bfe289db4f52ea824ca5a2f851c986b42 100644 (file)
@@ -1049,6 +1049,12 @@ static inline void sk_wmem_queued_add(struct sock *sk, int val)
        WRITE_ONCE(sk->sk_wmem_queued, sk->sk_wmem_queued + val);
 }
 
+static inline void sk_forward_alloc_add(struct sock *sk, int val)
+{
+       /* Paired with lockless reads of sk->sk_forward_alloc */
+       WRITE_ONCE(sk->sk_forward_alloc, sk->sk_forward_alloc + val);
+}
+
 void sk_stream_write_space(struct sock *sk);
 
 /* OOB backlog add */
@@ -1401,7 +1407,7 @@ static inline int sk_forward_alloc_get(const struct sock *sk)
        if (sk->sk_prot->forward_alloc_get)
                return sk->sk_prot->forward_alloc_get(sk);
 #endif
-       return sk->sk_forward_alloc;
+       return READ_ONCE(sk->sk_forward_alloc);
 }
 
 static inline bool __sk_stream_memory_free(const struct sock *sk, int wake)
@@ -1697,14 +1703,14 @@ static inline void sk_mem_charge(struct sock *sk, int size)
 {
        if (!sk_has_account(sk))
                return;
-       sk->sk_forward_alloc -= size;
+       sk_forward_alloc_add(sk, -size);
 }
 
 static inline void sk_mem_uncharge(struct sock *sk, int size)
 {
        if (!sk_has_account(sk))
                return;
-       sk->sk_forward_alloc += size;
+       sk_forward_alloc_add(sk, size);
        sk_mem_reclaim(sk);
 }
 
index 6ff58fa5f41ed558f0a2ae8dbccb6e21be41e191..aa628c6314f6434977dad0c971edfb67efcf918f 100644 (file)
@@ -1034,7 +1034,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
                mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
                return -ENOMEM;
        }
-       sk->sk_forward_alloc += pages << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
 
        WRITE_ONCE(sk->sk_reserved_mem,
                   sk->sk_reserved_mem + (pages << PAGE_SHIFT));
@@ -3082,10 +3082,10 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
 {
        int ret, amt = sk_mem_pages(size);
 
-       sk->sk_forward_alloc += amt << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
        ret = __sk_mem_raise_allocated(sk, size, amt, kind);
        if (!ret)
-               sk->sk_forward_alloc -= amt << PAGE_SHIFT;
+               sk_forward_alloc_add(sk, -(amt << PAGE_SHIFT));
        return ret;
 }
 EXPORT_SYMBOL(__sk_mem_schedule);
@@ -3117,7 +3117,7 @@ void __sk_mem_reduce_allocated(struct sock *sk, int amount)
 void __sk_mem_reclaim(struct sock *sk, int amount)
 {
        amount >>= PAGE_SHIFT;
-       sk->sk_forward_alloc -= amount << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, -(amount << PAGE_SHIFT));
        __sk_mem_reduce_allocated(sk, amount);
 }
 EXPORT_SYMBOL(__sk_mem_reclaim);
index 26bd039f9296f932029ffd7c21b78414a5341699..dc3166e56169f6610786877c0938d67923a81db8 100644 (file)
@@ -3380,7 +3380,7 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
        if (delta <= 0)
                return;
        amt = sk_mem_pages(delta);
-       sk->sk_forward_alloc += amt << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
        sk_memory_allocated_add(sk, amt);
 
        if (mem_cgroup_sockets_enabled && sk->sk_memcg)
index 42c1f7d9a980a2565eae82c5bc22b14c7636e3fc..b2aa7777521f6e7e2f6f70887b5bbe0e4d758327 100644 (file)
@@ -1474,9 +1474,9 @@ static void udp_rmem_release(struct sock *sk, int size, int partial,
                spin_lock(&sk_queue->lock);
 
 
-       sk->sk_forward_alloc += size;
+       sk_forward_alloc_add(sk, size);
        amt = (sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1);
-       sk->sk_forward_alloc -= amt;
+       sk_forward_alloc_add(sk, -amt);
 
        if (amt)
                __sk_mem_reduce_allocated(sk, amt >> PAGE_SHIFT);
@@ -1582,7 +1582,7 @@ int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb)
                sk->sk_forward_alloc += delta;
        }
 
-       sk->sk_forward_alloc -= size;
+       sk_forward_alloc_add(sk, -size);
 
        /* no need to setup a destructor, we will explicitly release the
         * forward allocated memory on dequeue
index 61fefa1a82db2851bd5ce7bc920da3500e6b1237..573db9c2bc1cd762437501a03efd1786e1e3d0c5 100644 (file)
@@ -1802,7 +1802,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                }
 
                /* data successfully copied into the write queue */
-               sk->sk_forward_alloc -= total_ts;
+               sk_forward_alloc_add(sk, -total_ts);
                copied += psize;
                dfrag->data_len += psize;
                frag_truesize += psize;
@@ -3278,7 +3278,7 @@ void mptcp_destroy_common(struct mptcp_sock *msk, unsigned int flags)
        /* move all the rx fwd alloc into the sk_mem_reclaim_final in
         * inet_sock_destruct() will dispose it
         */
-       sk->sk_forward_alloc += msk->rmem_fwd_alloc;
+       sk_forward_alloc_add(sk, msk->rmem_fwd_alloc);
        msk->rmem_fwd_alloc = 0;
        mptcp_token_destroy(msk);
        mptcp_pm_free_anno_list(msk);
@@ -3562,7 +3562,7 @@ static void mptcp_shutdown(struct sock *sk, int how)
 
 static int mptcp_forward_alloc_get(const struct sock *sk)
 {
-       return sk->sk_forward_alloc + mptcp_sk(sk)->rmem_fwd_alloc;
+       return READ_ONCE(sk->sk_forward_alloc) + mptcp_sk(sk)->rmem_fwd_alloc;
 }
 
 static int mptcp_ioctl_outq(const struct mptcp_sock *msk, u64 v)