]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
smb: smbdirect: introduce smbdirect_rw.c with server rw code
authorStefan Metzmacher <metze@samba.org>
Fri, 19 Sep 2025 07:07:03 +0000 (09:07 +0200)
committerSteve French <stfrench@microsoft.com>
Thu, 16 Apr 2026 02:58:19 +0000 (21:58 -0500)
This is basically contains the following functions copied from
the server: wait_for_rw_credits, calc_rw_credits, get_sg_list,
smb_direct_free_rdma_rw_msg, read_write_done, read_done,
write_done, smb_direct_rdma_xmit.

They got new names, some indentation/formatting changes,
some variable names are changed too.

They also only use struct smbdirect_socket instead of
struct smb_direct_transport.

But the logic is still the same. They will be used
by the server soon.

Cc: Steve French <smfrench@gmail.com>
Cc: Tom Talpey <tom@talpey.com>
Cc: Long Li <longli@microsoft.com>
Cc: Namjae Jeon <linkinjeon@kernel.org>
Cc: linux-cifs@vger.kernel.org
Cc: samba-technical@lists.samba.org
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Acked-by: Namjae Jeon <linkinjeon@kernel.org>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/smb/common/smbdirect/smbdirect_all_c_files.c
fs/smb/common/smbdirect/smbdirect_rw.c [new file with mode: 0644]
fs/smb/common/smbdirect/smbdirect_socket.h

index f1afc112075323390f0a0d961d3b1a6773f7c64f..963a1fc3b54b550dece2813c2f60d54937a11079 100644 (file)
@@ -18,3 +18,4 @@
 #include "smbdirect_socket.c"
 #include "smbdirect_connection.c"
 #include "smbdirect_mr.c"
+#include "smbdirect_rw.c"
diff --git a/fs/smb/common/smbdirect/smbdirect_rw.c b/fs/smb/common/smbdirect/smbdirect_rw.c
new file mode 100644 (file)
index 0000000..6eeec53
--- /dev/null
@@ -0,0 +1,255 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ *   Copyright (C) 2017, Microsoft Corporation.
+ *   Copyright (C) 2018, LG Electronics.
+ *   Copyright (c) 2025, Stefan Metzmacher
+ */
+
+#include "smbdirect_internal.h"
+
+static int smbdirect_connection_wait_for_rw_credits(struct smbdirect_socket *sc,
+                                                   int credits)
+{
+       return smbdirect_socket_wait_for_credits(sc,
+                                                SMBDIRECT_SOCKET_CONNECTED,
+                                                -ENOTCONN,
+                                                &sc->rw_io.credits.wait_queue,
+                                                &sc->rw_io.credits.count,
+                                                credits);
+}
+
+static int smbdirect_connection_calc_rw_credits(struct smbdirect_socket *sc,
+                                               const void *buf,
+                                               size_t len)
+{
+       return DIV_ROUND_UP(smbdirect_get_buf_page_count(buf, len),
+                           sc->rw_io.credits.num_pages);
+}
+
+static int smbdirect_connection_rdma_get_sg_list(void *buf,
+                                                size_t size,
+                                                struct scatterlist *sg_list,
+                                                size_t nentries)
+{
+       bool high = is_vmalloc_addr(buf);
+       struct page *page;
+       size_t offset, len;
+       int i = 0;
+
+       if (size == 0 || nentries < smbdirect_get_buf_page_count(buf, size))
+               return -EINVAL;
+
+       offset = offset_in_page(buf);
+       buf -= offset;
+       while (size > 0) {
+               len = min_t(size_t, PAGE_SIZE - offset, size);
+               if (high)
+                       page = vmalloc_to_page(buf);
+               else
+                       page = kmap_to_page(buf);
+
+               if (!sg_list)
+                       return -EINVAL;
+               sg_set_page(sg_list, page, len, offset);
+               sg_list = sg_next(sg_list);
+
+               buf += PAGE_SIZE;
+               size -= len;
+               offset = 0;
+               i++;
+       }
+
+       return i;
+}
+
+static void smbdirect_connection_rw_io_free(struct smbdirect_rw_io *msg,
+                                           enum dma_data_direction dir)
+{
+       struct smbdirect_socket *sc = msg->socket;
+
+       rdma_rw_ctx_destroy(&msg->rdma_ctx,
+                           sc->ib.qp,
+                           sc->ib.qp->port,
+                           msg->sgt.sgl,
+                           msg->sgt.nents,
+                           dir);
+       sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
+       kfree(msg);
+}
+
+static void smbdirect_connection_rdma_rw_done(struct ib_cq *cq, struct ib_wc *wc,
+                                             enum dma_data_direction dir)
+{
+       struct smbdirect_rw_io *msg =
+               container_of(wc->wr_cqe, struct smbdirect_rw_io, cqe);
+       struct smbdirect_socket *sc = msg->socket;
+
+       if (wc->status != IB_WC_SUCCESS) {
+               msg->error = -EIO;
+               pr_err("read/write error. opcode = %d, status = %s(%d)\n",
+                      wc->opcode, ib_wc_status_msg(wc->status), wc->status);
+               if (wc->status != IB_WC_WR_FLUSH_ERR)
+                       smbdirect_socket_schedule_cleanup(sc, msg->error);
+       }
+
+       complete(msg->completion);
+}
+
+static void smbdirect_connection_rdma_read_done(struct ib_cq *cq, struct ib_wc *wc)
+{
+       smbdirect_connection_rdma_rw_done(cq, wc, DMA_FROM_DEVICE);
+}
+
+static void smbdirect_connection_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc)
+{
+       smbdirect_connection_rdma_rw_done(cq, wc, DMA_TO_DEVICE);
+}
+
+__maybe_unused /* this is temporary while this file is included in others */
+static int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
+                                         void *buf, size_t buf_len,
+                                         struct smbdirect_buffer_descriptor_v1 *desc,
+                                         size_t desc_len,
+                                         bool is_read)
+{
+       const struct smbdirect_socket_parameters *sp = &sc->parameters;
+       enum dma_data_direction direction = is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
+       struct smbdirect_rw_io *msg, *next_msg;
+       size_t i;
+       int ret;
+       DECLARE_COMPLETION_ONSTACK(completion);
+       struct ib_send_wr *first_wr;
+       LIST_HEAD(msg_list);
+       u8 *desc_buf;
+       int credits_needed;
+       size_t desc_buf_len, desc_num = 0;
+
+       if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
+               return -ENOTCONN;
+
+       if (buf_len > sp->max_read_write_size)
+               return -EINVAL;
+
+       /* calculate needed credits */
+       credits_needed = 0;
+       desc_buf = buf;
+       for (i = 0; i < desc_len / sizeof(*desc); i++) {
+               if (!buf_len)
+                       break;
+
+               desc_buf_len = le32_to_cpu(desc[i].length);
+               if (!desc_buf_len)
+                       return -EINVAL;
+
+               if (desc_buf_len > buf_len) {
+                       desc_buf_len = buf_len;
+                       desc[i].length = cpu_to_le32(desc_buf_len);
+                       buf_len = 0;
+               }
+
+               credits_needed += smbdirect_connection_calc_rw_credits(sc,
+                                                                      desc_buf,
+                                                                      desc_buf_len);
+               desc_buf += desc_buf_len;
+               buf_len -= desc_buf_len;
+               desc_num++;
+       }
+
+       smbdirect_log_rdma_rw(sc, SMBDIRECT_LOG_INFO,
+               "RDMA %s, len %zu, needed credits %d\n",
+               str_read_write(is_read), buf_len, credits_needed);
+
+       ret = smbdirect_connection_wait_for_rw_credits(sc, credits_needed);
+       if (ret < 0)
+               return ret;
+
+       /* build rdma_rw_ctx for each descriptor */
+       desc_buf = buf;
+       for (i = 0; i < desc_num; i++) {
+               size_t page_count;
+
+               msg = kzalloc_flex(*msg, sg_list, SG_CHUNK_SIZE,
+                                  sc->rw_io.mem.gfp_mask);
+               if (!msg) {
+                       ret = -ENOMEM;
+                       goto out;
+               }
+
+               desc_buf_len = le32_to_cpu(desc[i].length);
+               page_count = smbdirect_get_buf_page_count(desc_buf, desc_buf_len);
+
+               msg->socket = sc;
+               msg->cqe.done = is_read ?
+                       smbdirect_connection_rdma_read_done :
+                       smbdirect_connection_rdma_write_done;
+               msg->completion = &completion;
+
+               msg->sgt.sgl = &msg->sg_list[0];
+               ret = sg_alloc_table_chained(&msg->sgt,
+                                            page_count,
+                                            msg->sg_list,
+                                            SG_CHUNK_SIZE);
+               if (ret) {
+                       ret = -ENOMEM;
+                       goto free_msg;
+               }
+
+               ret = smbdirect_connection_rdma_get_sg_list(desc_buf,
+                                                           desc_buf_len,
+                                                           msg->sgt.sgl,
+                                                           msg->sgt.orig_nents);
+               if (ret < 0)
+                       goto free_table;
+
+               ret = rdma_rw_ctx_init(&msg->rdma_ctx,
+                                      sc->ib.qp,
+                                      sc->ib.qp->port,
+                                      msg->sgt.sgl,
+                                      page_count,
+                                      0,
+                                      le64_to_cpu(desc[i].offset),
+                                      le32_to_cpu(desc[i].token),
+                                      direction);
+               if (ret < 0) {
+                       pr_err("failed to init rdma_rw_ctx: %d\n", ret);
+                       goto free_table;
+               }
+
+               list_add_tail(&msg->list, &msg_list);
+               desc_buf += desc_buf_len;
+       }
+
+       /* concatenate work requests of rdma_rw_ctxs */
+       first_wr = NULL;
+       list_for_each_entry_reverse(msg, &msg_list, list) {
+               first_wr = rdma_rw_ctx_wrs(&msg->rdma_ctx,
+                                          sc->ib.qp,
+                                          sc->ib.qp->port,
+                                          &msg->cqe,
+                                          first_wr);
+       }
+
+       ret = ib_post_send(sc->ib.qp, first_wr, NULL);
+       if (ret) {
+               pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
+               goto out;
+       }
+
+       msg = list_last_entry(&msg_list, struct smbdirect_rw_io, list);
+       wait_for_completion(&completion);
+       ret = msg->error;
+out:
+       list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
+               list_del(&msg->list);
+               smbdirect_connection_rw_io_free(msg, direction);
+       }
+       atomic_add(credits_needed, &sc->rw_io.credits.count);
+       wake_up(&sc->rw_io.credits.wait_queue);
+       return ret;
+
+free_table:
+       sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
+free_msg:
+       kfree(msg);
+       goto out;
+}
index b2882935a5d86bef7579ed5ac09033c3dc7d6117..36e6822c3795e1abab6154b5835e9a5946769381 100644 (file)
@@ -326,6 +326,14 @@ struct smbdirect_socket {
         * The state for RDMA read/write requests on the server
         */
        struct {
+               /*
+                * Memory hints for
+                * smbdirect_rw_io structs
+                */
+               struct {
+                       gfp_t gfp_mask;
+               } mem;
+
                /*
                 * The credit state for the send side
                 */
@@ -541,6 +549,7 @@ static __always_inline void smbdirect_socket_init(struct smbdirect_socket *sc)
        spin_lock_init(&sc->recv_io.reassembly.lock);
        init_waitqueue_head(&sc->recv_io.reassembly.wait_queue);
 
+       sc->rw_io.mem.gfp_mask = GFP_KERNEL;
        atomic_set(&sc->rw_io.credits.count, 0);
        init_waitqueue_head(&sc->rw_io.credits.wait_queue);