ublk: remove struct ublk_rq_data
author Caleb Sander Mateos <csander@purestorage.com>
Fri, 20 Jun 2025 15:09:56 +0000 (09:09 -0600)
committer Jens Axboe <axboe@kernel.dk>
Mon, 30 Jun 2025 21:50:53 +0000 (15:50 -0600)
__ublk_check_and_get_req() attempts to atomically look up the struct
request for a ublk I/O and take a reference on it. However, the request
can be freed between the lookup on the tagset in blk_mq_tag_to_rq() and
the increment of its reference count in ublk_get_req_ref(), for example
if an elevator switch happens concurrently.
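
In condensed form (not the verbatim driver code), the racy lookup looked
roughly like this, with the reference count living in the request's
per-driver data:

	req = blk_mq_tag_to_rq(ub->tag_set.tags[ubq->q_id], tag);
	if (!req)
		return NULL;
	/*
	 * Window: req may be freed and reallocated here, e.g. by a
	 * concurrent elevator switch tearing down the tagset's requests.
	 */
	data = blk_mq_rq_to_pdu(req);            /* pdu of a possibly freed request */
	if (!refcount_inc_not_zero(&data->ref))  /* potential use after free */
		return NULL;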

Fix the potential use-after-free by moving the reference count from
struct ublk_rq_data to struct ublk_io. Move the buf_index and
buf_ctx_handle fields as well to reduce the number of cache lines
touched when dispatching and completing a ublk I/O, allowing struct
ublk_rq_data to be removed entirely.
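
After this change the count lives in struct ublk_io, which is embedded in
the queue's ios[] array and outlives any individual request, so taking the
reference never touches request memory; condensed from the new
__ublk_check_and_get_req() in the diff below:

	req = blk_mq_tag_to_rq(ub->tag_set.tags[ubq->q_id], tag);
	if (!req)
		return NULL;
	/* io->ref is stable: &ubq->ios[tag] is valid for the queue's lifetime */
	if (!refcount_inc_not_zero(&io->ref))
		return NULL;
	/* the request is only trusted once the reference is held */
	if (unlikely(!blk_mq_request_started(req) || req->tag != tag))
		goto fail_put;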

Suggested-by: Ming Lei <ming.lei@redhat.com>
Signed-off-by: Caleb Sander Mateos <csander@purestorage.com>
Fixes: 62fe99cef94a ("ublk: add read()/write() support for ublk char device")
Reviewed-by: Ming Lei <ming.lei@redhat.com>
Link: https://lore.kernel.org/r/20250620151008.3976463-3-csander@purestorage.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
drivers/block/ublk_drv.c

index 9e48e0c1b0ccac4f68cc834487722712d05359a1..0ebf3bbb343fe2e61cf562f0526f62c544dc1a5a 100644
         UBLK_PARAM_TYPE_DEVT | UBLK_PARAM_TYPE_ZONED |    \
         UBLK_PARAM_TYPE_DMA_ALIGN | UBLK_PARAM_TYPE_SEGMENT)
 
-struct ublk_rq_data {
-       refcount_t ref;
-
-       /* for auto-unregister buffer in case of UBLK_F_AUTO_BUF_REG */
-       u16 buf_index;
-       void *buf_ctx_handle;
-};
-
 struct ublk_uring_cmd_pdu {
        /*
         * Store requests in same batch temporarily for queuing them to
@@ -169,6 +161,22 @@ struct ublk_io {
        };
 
        struct task_struct *task;
+
+       /*
+        * The number of uses of this I/O by the ublk server
+        * if user copy or zero copy are enabled:
+        * - 1 from dispatch to the server until UBLK_IO_COMMIT_AND_FETCH_REQ
+        * - 1 for each inflight ublk_ch_{read,write}_iter() call
+        * - 1 for each io_uring registered buffer
+        * The I/O can only be completed once all references are dropped.
+        * User copy and buffer registration operations are only permitted
+        * if the reference count is nonzero.
+        */
+       refcount_t ref;
+
+       /* auto-registered buffer, valid if UBLK_IO_FLAG_AUTO_BUF_REG is set */
+       u16 buf_index;
+       void *buf_ctx_handle;
 };
 
 struct ublk_queue {
@@ -228,7 +236,8 @@ static void ublk_io_release(void *priv);
 static void ublk_stop_dev_unlocked(struct ublk_device *ub);
 static void ublk_abort_queue(struct ublk_device *ub, struct ublk_queue *ubq);
 static inline struct request *__ublk_check_and_get_req(struct ublk_device *ub,
-               const struct ublk_queue *ubq, int tag, size_t offset);
+               const struct ublk_queue *ubq, struct ublk_io *io,
+               size_t offset);
 static inline unsigned int ublk_req_build_flags(struct request *req);
 
 static inline struct ublksrv_io_desc *
@@ -673,34 +682,26 @@ static inline bool ublk_need_req_ref(const struct ublk_queue *ubq)
 }
 
 static inline void ublk_init_req_ref(const struct ublk_queue *ubq,
-               struct request *req)
+               struct ublk_io *io)
 {
-       if (ublk_need_req_ref(ubq)) {
-               struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-               refcount_set(&data->ref, 1);
-       }
+       if (ublk_need_req_ref(ubq))
+               refcount_set(&io->ref, 1);
 }
 
 static inline bool ublk_get_req_ref(const struct ublk_queue *ubq,
-               struct request *req)
+               struct ublk_io *io)
 {
-       if (ublk_need_req_ref(ubq)) {
-               struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-               return refcount_inc_not_zero(&data->ref);
-       }
+       if (ublk_need_req_ref(ubq))
+               return refcount_inc_not_zero(&io->ref);
 
        return true;
 }
 
 static inline void ublk_put_req_ref(const struct ublk_queue *ubq,
-               struct request *req)
+               struct ublk_io *io, struct request *req)
 {
        if (ublk_need_req_ref(ubq)) {
-               struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-               if (refcount_dec_and_test(&data->ref))
+               if (refcount_dec_and_test(&io->ref))
                        __ublk_complete_rq(req);
        } else {
                __ublk_complete_rq(req);
@@ -1188,39 +1189,38 @@ static inline void __ublk_abort_rq(struct ublk_queue *ubq,
                blk_mq_end_request(rq, BLK_STS_IOERR);
 }
 
-static void ublk_auto_buf_reg_fallback(struct request *req)
+static void
+ublk_auto_buf_reg_fallback(const struct ublk_queue *ubq, struct ublk_io *io)
 {
-       const struct ublk_queue *ubq = req->mq_hctx->driver_data;
-       struct ublksrv_io_desc *iod = ublk_get_iod(ubq, req->tag);
-       struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
+       unsigned tag = io - ubq->ios;
+       struct ublksrv_io_desc *iod = ublk_get_iod(ubq, tag);
 
        iod->op_flags |= UBLK_IO_F_NEED_REG_BUF;
-       refcount_set(&data->ref, 1);
+       refcount_set(&io->ref, 1);
 }
 
-static bool ublk_auto_buf_reg(struct request *req, struct ublk_io *io,
-                             unsigned int issue_flags)
+static bool ublk_auto_buf_reg(const struct ublk_queue *ubq, struct request *req,
+                             struct ublk_io *io, unsigned int issue_flags)
 {
        struct ublk_uring_cmd_pdu *pdu = ublk_get_uring_cmd_pdu(io->cmd);
-       struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
        int ret;
 
        ret = io_buffer_register_bvec(io->cmd, req, ublk_io_release,
                                      pdu->buf.index, issue_flags);
        if (ret) {
                if (pdu->buf.flags & UBLK_AUTO_BUF_REG_FALLBACK) {
-                       ublk_auto_buf_reg_fallback(req);
+                       ublk_auto_buf_reg_fallback(ubq, io);
                        return true;
                }
                blk_mq_end_request(req, BLK_STS_IOERR);
                return false;
        }
        /* one extra reference is dropped by ublk_io_release */
-       refcount_set(&data->ref, 2);
+       refcount_set(&io->ref, 2);
 
-       data->buf_ctx_handle = io_uring_cmd_ctx_handle(io->cmd);
+       io->buf_ctx_handle = io_uring_cmd_ctx_handle(io->cmd);
        /* store buffer index in request payload */
-       data->buf_index = pdu->buf.index;
+       io->buf_index = pdu->buf.index;
        io->flags |= UBLK_IO_FLAG_AUTO_BUF_REG;
        return true;
 }
@@ -1230,9 +1230,9 @@ static bool ublk_prep_auto_buf_reg(struct ublk_queue *ubq,
                                   unsigned int issue_flags)
 {
        if (ublk_support_auto_buf_reg(ubq) && ublk_rq_has_data(req))
-               return ublk_auto_buf_reg(req, io, issue_flags);
+               return ublk_auto_buf_reg(ubq, req, io, issue_flags);
 
-       ublk_init_req_ref(ubq, req);
+       ublk_init_req_ref(ubq, io);
        return true;
 }
 
@@ -1503,6 +1503,8 @@ static void ublk_queue_reinit(struct ublk_device *ub, struct ublk_queue *ubq)
                        put_task_struct(io->task);
                        io->task = NULL;
                }
+
+               WARN_ON_ONCE(refcount_read(&io->ref));
        }
 }
 
@@ -2007,12 +2009,14 @@ static void ublk_io_release(void *priv)
 {
        struct request *rq = priv;
        struct ublk_queue *ubq = rq->mq_hctx->driver_data;
+       struct ublk_io *io = &ubq->ios[rq->tag];
 
-       ublk_put_req_ref(ubq, rq);
+       ublk_put_req_ref(ubq, io, rq);
 }
 
 static int ublk_register_io_buf(struct io_uring_cmd *cmd,
-                               const struct ublk_queue *ubq, unsigned int tag,
+                               const struct ublk_queue *ubq,
+                               struct ublk_io *io,
                                unsigned int index, unsigned int issue_flags)
 {
        struct ublk_device *ub = cmd->file->private_data;
@@ -2022,14 +2026,14 @@ static int ublk_register_io_buf(struct io_uring_cmd *cmd,
        if (!ublk_support_zero_copy(ubq))
                return -EINVAL;
 
-       req = __ublk_check_and_get_req(ub, ubq, tag, 0);
+       req = __ublk_check_and_get_req(ub, ubq, io, 0);
        if (!req)
                return -EINVAL;
 
        ret = io_buffer_register_bvec(cmd, req, ublk_io_release, index,
                                      issue_flags);
        if (ret) {
-               ublk_put_req_ref(ubq, req);
+               ublk_put_req_ref(ubq, io, req);
                return ret;
        }
 
@@ -2136,10 +2140,8 @@ static int ublk_commit_and_fetch(const struct ublk_queue *ubq,
                 * this ublk request gets stuck.
                 */
                if (io->flags & UBLK_IO_FLAG_AUTO_BUF_REG) {
-                       struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-                       if (data->buf_ctx_handle == io_uring_cmd_ctx_handle(cmd))
-                               io_buffer_unregister_bvec(cmd, data->buf_index,
+                       if (io->buf_ctx_handle == io_uring_cmd_ctx_handle(cmd))
+                               io_buffer_unregister_bvec(cmd, io->buf_index,
                                                issue_flags);
                        io->flags &= ~UBLK_IO_FLAG_AUTO_BUF_REG;
                }
@@ -2159,7 +2161,7 @@ static int ublk_commit_and_fetch(const struct ublk_queue *ubq,
                req->__sector = ub_cmd->zone_append_lba;
 
        if (likely(!blk_should_fake_timeout(req->q)))
-               ublk_put_req_ref(ubq, req);
+               ublk_put_req_ref(ubq, io, req);
 
        return 0;
 }
@@ -2238,7 +2240,7 @@ static int __ublk_ch_uring_cmd(struct io_uring_cmd *cmd,
        ret = -EINVAL;
        switch (_IOC_NR(cmd_op)) {
        case UBLK_IO_REGISTER_IO_BUF:
-               return ublk_register_io_buf(cmd, ubq, tag, ub_cmd->addr, issue_flags);
+               return ublk_register_io_buf(cmd, ubq, io, ub_cmd->addr, issue_flags);
        case UBLK_IO_UNREGISTER_IO_BUF:
                return ublk_unregister_io_buf(cmd, ubq, ub_cmd->addr, issue_flags);
        case UBLK_IO_FETCH_REQ:
@@ -2278,15 +2280,20 @@ static int __ublk_ch_uring_cmd(struct io_uring_cmd *cmd,
 }
 
 static inline struct request *__ublk_check_and_get_req(struct ublk_device *ub,
-               const struct ublk_queue *ubq, int tag, size_t offset)
+               const struct ublk_queue *ubq, struct ublk_io *io, size_t offset)
 {
+       unsigned tag = io - ubq->ios;
        struct request *req;
 
+       /*
+        * can't use io->req in case of concurrent UBLK_IO_COMMIT_AND_FETCH_REQ,
+        * which would overwrite it with io->cmd
+        */
        req = blk_mq_tag_to_rq(ub->tag_set.tags[ubq->q_id], tag);
        if (!req)
                return NULL;
 
-       if (!ublk_get_req_ref(ubq, req))
+       if (!ublk_get_req_ref(ubq, io))
                return NULL;
 
        if (unlikely(!blk_mq_request_started(req) || req->tag != tag))
@@ -2300,7 +2307,7 @@ static inline struct request *__ublk_check_and_get_req(struct ublk_device *ub,
 
        return req;
 fail_put:
-       ublk_put_req_ref(ubq, req);
+       ublk_put_req_ref(ubq, io, req);
        return NULL;
 }
 
@@ -2367,7 +2374,8 @@ static inline bool ublk_check_ubuf_dir(const struct request *req,
 }
 
 static struct request *ublk_check_and_get_req(struct kiocb *iocb,
-               struct iov_iter *iter, size_t *off, int dir)
+               struct iov_iter *iter, size_t *off, int dir,
+               struct ublk_io **io)
 {
        struct ublk_device *ub = iocb->ki_filp->private_data;
        struct ublk_queue *ubq;
@@ -2401,7 +2409,8 @@ static struct request *ublk_check_and_get_req(struct kiocb *iocb,
        if (tag >= ubq->q_depth)
                return ERR_PTR(-EINVAL);
 
-       req = __ublk_check_and_get_req(ub, ubq, tag, buf_off);
+       *io = &ubq->ios[tag];
+       req = __ublk_check_and_get_req(ub, ubq, *io, buf_off);
        if (!req)
                return ERR_PTR(-EINVAL);
 
@@ -2414,7 +2423,7 @@ static struct request *ublk_check_and_get_req(struct kiocb *iocb,
        *off = buf_off;
        return req;
 fail:
-       ublk_put_req_ref(ubq, req);
+       ublk_put_req_ref(ubq, *io, req);
        return ERR_PTR(-EACCES);
 }
 
@@ -2422,16 +2431,17 @@ static ssize_t ublk_ch_read_iter(struct kiocb *iocb, struct iov_iter *to)
 {
        struct ublk_queue *ubq;
        struct request *req;
+       struct ublk_io *io;
        size_t buf_off;
        size_t ret;
 
-       req = ublk_check_and_get_req(iocb, to, &buf_off, ITER_DEST);
+       req = ublk_check_and_get_req(iocb, to, &buf_off, ITER_DEST, &io);
        if (IS_ERR(req))
                return PTR_ERR(req);
 
        ret = ublk_copy_user_pages(req, buf_off, to, ITER_DEST);
        ubq = req->mq_hctx->driver_data;
-       ublk_put_req_ref(ubq, req);
+       ublk_put_req_ref(ubq, io, req);
 
        return ret;
 }
@@ -2440,16 +2450,17 @@ static ssize_t ublk_ch_write_iter(struct kiocb *iocb, struct iov_iter *from)
 {
        struct ublk_queue *ubq;
        struct request *req;
+       struct ublk_io *io;
        size_t buf_off;
        size_t ret;
 
-       req = ublk_check_and_get_req(iocb, from, &buf_off, ITER_SOURCE);
+       req = ublk_check_and_get_req(iocb, from, &buf_off, ITER_SOURCE, &io);
        if (IS_ERR(req))
                return PTR_ERR(req);
 
        ret = ublk_copy_user_pages(req, buf_off, from, ITER_SOURCE);
        ubq = req->mq_hctx->driver_data;
-       ublk_put_req_ref(ubq, req);
+       ublk_put_req_ref(ubq, io, req);
 
        return ret;
 }
@@ -2474,6 +2485,7 @@ static void ublk_deinit_queue(struct ublk_device *ub, int q_id)
                struct ublk_io *io = &ubq->ios[i];
                if (io->task)
                        put_task_struct(io->task);
+               WARN_ON_ONCE(refcount_read(&io->ref));
        }
 
        if (ubq->io_cmd_buf)
@@ -2626,7 +2638,6 @@ static int ublk_add_tag_set(struct ublk_device *ub)
        ub->tag_set.nr_hw_queues = ub->dev_info.nr_hw_queues;
        ub->tag_set.queue_depth = ub->dev_info.queue_depth;
        ub->tag_set.numa_node = NUMA_NO_NODE;
-       ub->tag_set.cmd_size = sizeof(struct ublk_rq_data);
        ub->tag_set.driver_data = ub;
        return blk_mq_alloc_tag_set(&ub->tag_set);
 }