]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
RDMA/rxe: Enable asynchronous prefetch for ODP MRs
authorDaisuke Matsuda <dskmtsd@gmail.com>
Thu, 22 May 2025 11:19:55 +0000 (11:19 +0000)
committerLeon Romanovsky <leon@kernel.org>
Thu, 12 Jun 2025 08:09:42 +0000 (04:09 -0400)
Calling ibv_advise_mr(3) with flags other than IBV_ADVISE_MR_FLAG_FLUSH
invokes an asynchronous request. It is best-effort, and thus can safely be
deferred to the system-wide workqueue.

The reference counter in rxe_mr is used to ensure that the MRs persist and
that rxe is not terminated until the queued work is done.

Signed-off-by: Daisuke Matsuda <dskmtsd@gmail.com>
Link: https://patch.msgid.link/20250522111955.3227-3-dskmtsd@gmail.com
Signed-off-by: Leon Romanovsky <leon@kernel.org>
drivers/infiniband/sw/rxe/rxe_odp.c

index c0413181acc2f3eb9ff5c5ff2abbeb58b4bd32b0..6313680e9d40d47f566b8bf021b5f967769b5988 100644 (file)
@@ -419,6 +419,52 @@ enum resp_states rxe_odp_do_atomic_write(struct rxe_mr *mr, u64 iova, u64 value)
        return RESPST_NONE;
 }
 
+struct prefetch_mr_work {
+       struct work_struct work;
+       u32 pf_flags;
+       u32 num_sge;
+       struct {
+               u64 io_virt;
+               struct rxe_mr *mr;
+               size_t length;
+       } frags[];
+};
+
+static void rxe_ib_prefetch_mr_work(struct work_struct *w)
+{
+       struct prefetch_mr_work *work =
+               container_of(w, struct prefetch_mr_work, work);
+       int ret;
+       u32 i;
+
+       /*
+        * We rely on IB/core that work is executed
+        * if we have num_sge != 0 only.
+        */
+       WARN_ON(!work->num_sge);
+       for (i = 0; i < work->num_sge; ++i) {
+               struct ib_umem_odp *umem_odp;
+
+               ret = rxe_odp_do_pagefault_and_lock(work->frags[i].mr,
+                                                   work->frags[i].io_virt,
+                                                   work->frags[i].length,
+                                                   work->pf_flags);
+               if (ret < 0) {
+                       rxe_dbg_mr(work->frags[i].mr,
+                                  "failed to prefetch the mr\n");
+                       goto deref;
+               }
+
+               umem_odp = to_ib_umem_odp(work->frags[i].mr->umem);
+               mutex_unlock(&umem_odp->umem_mutex);
+
+deref:
+               rxe_put(work->frags[i].mr);
+       }
+
+       kvfree(work);
+}
+
 static int rxe_ib_prefetch_sg_list(struct ib_pd *ibpd,
                                   enum ib_uverbs_advise_mr_advice advice,
                                   u32 pf_flags, struct ib_sge *sg_list,
@@ -470,7 +516,11 @@ static int rxe_ib_advise_mr_prefetch(struct ib_pd *ibpd,
                                     u32 flags, struct ib_sge *sg_list,
                                     u32 num_sge)
 {
+       struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
        u32 pf_flags = RXE_PAGEFAULT_DEFAULT;
+       struct prefetch_mr_work *work;
+       struct rxe_mr *mr;
+       u32 i;
 
        if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
                pf_flags |= RXE_PAGEFAULT_RDONLY;
@@ -483,8 +533,41 @@ static int rxe_ib_advise_mr_prefetch(struct ib_pd *ibpd,
                return rxe_ib_prefetch_sg_list(ibpd, advice, pf_flags, sg_list,
                                               num_sge);
 
-       /* Asynchronous call is to be added in the next patch */
-       return -EOPNOTSUPP;
+       /* Asynchronous call is "best-effort" and allowed to fail */
+       work = kvzalloc(struct_size(work, frags, num_sge), GFP_KERNEL);
+       if (!work)
+               return -ENOMEM;
+
+       INIT_WORK(&work->work, rxe_ib_prefetch_mr_work);
+       work->pf_flags = pf_flags;
+       work->num_sge = num_sge;
+
+       for (i = 0; i < num_sge; ++i) {
+               /* Takes a reference, which will be released in the queued work */
+               mr = lookup_mr(pd, IB_ACCESS_LOCAL_WRITE,
+                              sg_list[i].lkey, RXE_LOOKUP_LOCAL);
+               if (IS_ERR(mr))
+                       goto err;
+
+               work->frags[i].io_virt = sg_list[i].addr;
+               work->frags[i].length = sg_list[i].length;
+               work->frags[i].mr = mr;
+       }
+
+       queue_work(system_unbound_wq, &work->work);
+
+       return 0;
+
+ err:
+       /* rollback reference counts for the invalid request */
+       while (i > 0) {
+               i--;
+               rxe_put(work->frags[i].mr);
+       }
+
+       kvfree(work);
+
+       return PTR_ERR(mr);
 }
 
 int rxe_ib_advise_mr(struct ib_pd *ibpd,