io_uring: consistently use rcu semantics with sqpoll thread
Author:     Keith Busch <kbusch@kernel.org>
AuthorDate: Wed, 11 Jun 2025 20:53:43 +0000 (13:53 -0700)
Commit:     Jens Axboe <axboe@kernel.dk>
CommitDate: Thu, 12 Jun 2025 14:17:09 +0000 (08:17 -0600)
The sqpoll thread pointer is dereferenced under RCU read protection in
one place, so the field needs to be annotated as an __rcu type, and all
accesses and assignments should consistently go through the RCU helpers
to keep sparse happy.

Since most of the accesses occur under sqd->lock, we can use
rcu_dereference_protected() without entering an RCU read-side critical
section. Provide a simple helper that fetches the thread from a locked
context.

Fixes: ac0b8b327a5677d ("io_uring: fix use-after-free of sq->thread in __io_uring_show_fdinfo()")
Signed-off-by: Keith Busch <kbusch@kernel.org>
Link: https://lore.kernel.org/r/20250611205343.1821117-1-kbusch@meta.com
[axboe: fold in fix for register.c]
Signed-off-by: Jens Axboe <axboe@kernel.dk>
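
To make the pattern concrete, here is a minimal standalone sketch of the
same annotation scheme. All names (my_data, my_task_locked,
my_publish_thread) are made up for the illustration; this is not the
io_uring code itself:

#include <linux/mutex.h>
#include <linux/rcupdate.h>
#include <linux/sched.h>

struct my_data {
	struct mutex lock;
	struct task_struct __rcu *thread;	/* written only under ->lock */
};

/* Lock-side accessor: no rcu_read_lock() needed, lockdep vouches for ->lock. */
static inline struct task_struct *my_task_locked(struct my_data *d)
{
	return rcu_dereference_protected(d->thread,
					 lockdep_is_held(&d->lock));
}

static void my_publish_thread(struct my_data *d, struct task_struct *tsk)
{
	mutex_lock(&d->lock);
	rcu_assign_pointer(d->thread, tsk);	/* pairs with rcu_dereference() */
	mutex_unlock(&d->lock);
}

Once the field carries __rcu, sparse flags any plain load or store, which
is what forces every access through one of these helpers.
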
io_uring/io_uring.c
io_uring/register.c
io_uring/sqpoll.c
io_uring/sqpoll.h

diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index cf759c172083c543207a24496042be5564a460f0..4e32f808d07df2e683a9b53e97f5a02c420b8cb1 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -2906,7 +2906,7 @@ static __cold void io_ring_exit_work(struct work_struct *work)
                        struct task_struct *tsk;
 
                        io_sq_thread_park(sqd);
-                       tsk = sqd->thread;
+                       tsk = sqpoll_task_locked(sqd);
                        if (tsk && tsk->io_uring && tsk->io_uring->io_wq)
                                io_wq_cancel_cb(tsk->io_uring->io_wq,
                                                io_cancel_ctx_cb, ctx, true);
@@ -3142,7 +3142,7 @@ __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
        s64 inflight;
        DEFINE_WAIT(wait);
 
-       WARN_ON_ONCE(sqd && sqd->thread != current);
+       WARN_ON_ONCE(sqd && sqpoll_task_locked(sqd) != current);
 
        if (!current->io_uring)
                return;
diff --git a/io_uring/register.c b/io_uring/register.c
index cc23a4c205cd43250f1fc47849052746e5d521e7..a59589249fce7acddebd9f1c2e80ffd2f879abd3 100644
--- a/io_uring/register.c
+++ b/io_uring/register.c
@@ -273,6 +273,8 @@ static __cold int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
        if (ctx->flags & IORING_SETUP_SQPOLL) {
                sqd = ctx->sq_data;
                if (sqd) {
+                       struct task_struct *tsk;
+
                        /*
                         * Observe the correct sqd->lock -> ctx->uring_lock
                         * ordering. Fine to drop uring_lock here, we hold
@@ -282,8 +284,9 @@ static __cold int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
                        mutex_unlock(&ctx->uring_lock);
                        mutex_lock(&sqd->lock);
                        mutex_lock(&ctx->uring_lock);
-                       if (sqd->thread)
-                               tctx = sqd->thread->io_uring;
+                       tsk = sqpoll_task_locked(sqd);
+                       if (tsk)
+                               tctx = tsk->io_uring;
                }
        } else {
                tctx = current->io_uring;
diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c
index 0625a421626f4cf09af1399ca2cf26f162a0aa90..268d2fbe6160c25b668940f482e5086e45727192 100644
--- a/io_uring/sqpoll.c
+++ b/io_uring/sqpoll.c
@@ -30,7 +30,7 @@ enum {
 void io_sq_thread_unpark(struct io_sq_data *sqd)
        __releases(&sqd->lock)
 {
-       WARN_ON_ONCE(sqd->thread == current);
+       WARN_ON_ONCE(sqpoll_task_locked(sqd) == current);
 
        /*
         * Do the dance but not conditional clear_bit() because it'd race with
@@ -46,24 +46,32 @@ void io_sq_thread_unpark(struct io_sq_data *sqd)
 void io_sq_thread_park(struct io_sq_data *sqd)
        __acquires(&sqd->lock)
 {
-       WARN_ON_ONCE(data_race(sqd->thread) == current);
+       struct task_struct *tsk;
 
        atomic_inc(&sqd->park_pending);
        set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
        mutex_lock(&sqd->lock);
-       if (sqd->thread)
-               wake_up_process(sqd->thread);
+
+       tsk = sqpoll_task_locked(sqd);
+       if (tsk) {
+               WARN_ON_ONCE(tsk == current);
+               wake_up_process(tsk);
+       }
 }
 
 void io_sq_thread_stop(struct io_sq_data *sqd)
 {
-       WARN_ON_ONCE(sqd->thread == current);
+       struct task_struct *tsk;
+
        WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state));
 
        set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
        mutex_lock(&sqd->lock);
-       if (sqd->thread)
-               wake_up_process(sqd->thread);
+       tsk = sqpoll_task_locked(sqd);
+       if (tsk) {
+               WARN_ON_ONCE(tsk == current);
+               wake_up_process(tsk);
+       }
        mutex_unlock(&sqd->lock);
        wait_for_completion(&sqd->exited);
 }
@@ -486,7 +494,10 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx,
                        goto err_sqpoll;
                }
 
-               sqd->thread = tsk;
+               mutex_lock(&sqd->lock);
+               rcu_assign_pointer(sqd->thread, tsk);
+               mutex_unlock(&sqd->lock);
+
                task_to_put = get_task_struct(tsk);
                ret = io_uring_alloc_task_context(tsk, ctx);
                wake_up_new_task(tsk);
@@ -514,10 +525,13 @@ __cold int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx,
        int ret = -EINVAL;
 
        if (sqd) {
+               struct task_struct *tsk;
+
                io_sq_thread_park(sqd);
                /* Don't set affinity for a dying thread */
-               if (sqd->thread)
-                       ret = io_wq_cpu_affinity(sqd->thread->io_uring, mask);
+               tsk = sqpoll_task_locked(sqd);
+               if (tsk)
+                       ret = io_wq_cpu_affinity(tsk->io_uring, mask);
                io_sq_thread_unpark(sqd);
        }
 
diff --git a/io_uring/sqpoll.h b/io_uring/sqpoll.h
index 4171666b1cf4cc37b84cb4079483bdf7b762add1..b83dcdec9765fd037da94d9b2f1173617b5a8652 100644
--- a/io_uring/sqpoll.h
+++ b/io_uring/sqpoll.h
@@ -8,7 +8,7 @@ struct io_sq_data {
        /* ctx's that are using this sqd */
        struct list_head        ctx_list;
 
-       struct task_struct      *thread;
+       struct task_struct __rcu *thread;
        struct wait_queue_head  wait;
 
        unsigned                sq_thread_idle;
@@ -29,3 +29,9 @@ void io_sq_thread_unpark(struct io_sq_data *sqd);
 void io_put_sq_data(struct io_sq_data *sqd);
 void io_sqpoll_wait_sq(struct io_ring_ctx *ctx);
 int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx, cpumask_var_t mask);
+
+static inline struct task_struct *sqpoll_task_locked(struct io_sq_data *sqd)
+{
+       return rcu_dereference_protected(sqd->thread,
+                                        lockdep_is_held(&sqd->lock));
+}
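
For completeness, the one lock-free reader the message alludes to (the
fdinfo path addressed by the commit in the Fixes tag) would pair with the
classic RCU read-side pattern. A rough sketch with a hypothetical helper
name, not the literal fdinfo code:

static struct task_struct *sqpoll_task_get(struct io_sq_data *sqd)
{
	struct task_struct *tsk;

	rcu_read_lock();
	tsk = rcu_dereference(sqd->thread);
	if (tsk)
		get_task_struct(tsk);	/* pin the task before leaving the read section */
	rcu_read_unlock();
	return tsk;
}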