]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
smb: server: make use of smbdirect_socket_{listen,accept}()
authorStefan Metzmacher <metze@samba.org>
Fri, 14 Nov 2025 14:41:02 +0000 (15:41 +0100)
committerSteve French <stfrench@microsoft.com>
Thu, 16 Apr 2026 02:58:24 +0000 (21:58 -0500)
We no longer need the custom rdma listener.

The code logic is very similar to transport_tcp.c now
using a kernel thread that loops over smbdirect_socket_accept().

This is the first step in the direction of using IPPROTO_SMBDIRECT
sockets in future.

Cc: Namjae Jeon <linkinjeon@kernel.org>
Cc: Steve French <smfrench@gmail.com>
Cc: Tom Talpey <tom@talpey.com>
Cc: linux-cifs@vger.kernel.org
Cc: samba-technical@lists.samba.org
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Acked-by: Namjae Jeon <linkinjeon@kernel.org>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/smb/server/transport_rdma.c

index e58d7e89da0ef937e9710cfc73719aa01df3bd91..7171bde9d078828d2f5fc99b3d1618c7407ee788 100644 (file)
@@ -89,7 +89,10 @@ struct smb_direct_device {
 
 static struct smb_direct_listener {
        int                     port;
-       struct rdma_cm_id       *cm_id;
+
+       struct task_struct      *thread;
+
+       struct smbdirect_socket *socket;
 } smb_direct_ib_listener, smb_direct_iw_listener;
 
 static struct workqueue_struct *smb_direct_wq;
@@ -185,49 +188,15 @@ unsigned int get_smbd_max_read_write_size(struct ksmbd_transport *kt)
        return sp->max_read_write_size;
 }
 
-static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
+static struct smb_direct_transport *alloc_transport(struct smbdirect_socket *sc)
 {
        struct smb_direct_transport *t;
-       struct smbdirect_socket *sc;
-       struct smbdirect_socket_parameters init_params = {};
-       struct smbdirect_socket_parameters *sp;
        struct ksmbd_conn *conn;
-       int ret;
-
-       /*
-        * Create the initial parameters
-        */
-       sp = &init_params;
-       sp->negotiate_timeout_msec = SMB_DIRECT_NEGOTIATE_TIMEOUT * 1000;
-       sp->initiator_depth = SMB_DIRECT_CM_INITIATOR_DEPTH;
-       sp->responder_resources = 1;
-       sp->recv_credit_max = smb_direct_receive_credit_max;
-       sp->send_credit_target = smb_direct_send_credit_target;
-       sp->max_send_size = smb_direct_max_send_size;
-       sp->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size;
-       sp->max_recv_size = smb_direct_max_receive_size;
-       sp->max_read_write_size = smb_direct_max_read_write_size;
-       sp->keepalive_interval_msec = SMB_DIRECT_KEEPALIVE_SEND_INTERVAL * 1000;
-       sp->keepalive_timeout_msec = SMB_DIRECT_KEEPALIVE_RECV_TIMEOUT * 1000;
 
        t = kzalloc_obj(*t, KSMBD_DEFAULT_GFP);
        if (!t)
                return NULL;
-       ret = smbdirect_socket_create_accepting(cm_id, &sc);
-       if (ret)
-               goto socket_create_failed;
-       smbdirect_socket_set_logging(sc, NULL,
-                                    smb_direct_logging_needed,
-                                    smb_direct_logging_vaprintf);
-       ret = smbdirect_socket_set_initial_parameters(sc, sp);
-       if (ret)
-               goto set_params_failed;
-       ret = smbdirect_socket_set_kernel_settings(sc, IB_POLL_WORKQUEUE, KSMBD_DEFAULT_GFP);
-       if (ret)
-               goto set_settings_failed;
-       ret = smbdirect_socket_set_custom_workqueue(sc, smb_direct_wq);
-       if (ret)
-               goto set_workqueue_failed;
+       t->socket = sc;
 
        conn = ksmbd_conn_alloc();
        if (!conn)
@@ -241,15 +210,9 @@ static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
        KSMBD_TRANS(t)->conn = conn;
        KSMBD_TRANS(t)->ops = &ksmbd_smb_direct_transport_ops;
 
-       t->socket = sc;
        return t;
 
 conn_alloc_failed:
-set_workqueue_failed:
-set_settings_failed:
-set_params_failed:
-       smbdirect_socket_release(sc);
-socket_create_failed:
        kfree(t);
        return NULL;
 }
@@ -346,48 +309,18 @@ static void smb_direct_shutdown(struct ksmbd_transport *t)
        smbdirect_socket_shutdown(sc);
 }
 
-static int smb_direct_prepare(struct ksmbd_transport *t)
-{
-       struct smb_direct_transport *st = SMBD_TRANS(t);
-       struct smbdirect_socket *sc = st->socket;
-       int ret;
-
-       ksmbd_debug(RDMA, "SMB_DIRECT Waiting for connection\n");
-       ret = smbdirect_connection_wait_for_connected(sc);
-       if (ret) {
-               ksmbd_debug(RDMA, "SMB_DIRECT connection failed %d => %1pe\n",
-                           ret, ERR_PTR(ret));
-               return ret;
-       }
-
-       ksmbd_debug(RDMA, "SMB_DIRECT connection ready\n");
-       return 0;
-}
-
-static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id,
-                                            struct rdma_cm_event *event)
+static int smb_direct_new_connection(struct smb_direct_listener *listener,
+                                    struct smbdirect_socket *client_sc)
 {
-       struct smb_direct_listener *listener = new_cm_id->context;
        struct smb_direct_transport *t;
-       struct smbdirect_socket *sc;
        struct task_struct *handler;
        int ret;
 
-       if (!smbdirect_frwr_is_supported(&new_cm_id->device->attrs)) {
-               ksmbd_debug(RDMA,
-                           "Fast Registration Work Requests is not supported. device capabilities=%llx\n",
-                           new_cm_id->device->attrs.device_cap_flags);
-               return -EPROTONOSUPPORT;
-       }
-
-       t = alloc_transport(new_cm_id);
-       if (!t)
+       t = alloc_transport(client_sc);
+       if (!t) {
+               smbdirect_socket_release(client_sc);
                return -ENOMEM;
-       sc = t->socket;
-
-       ret = smbdirect_accept_connect_request(sc, &event->param.conn);
-       if (ret)
-               goto out_err;
+       }
 
        handler = kthread_run(ksmbd_conn_handler_loop,
                              KSMBD_TRANS(t)->conn, "ksmbd:r%u",
@@ -404,41 +337,68 @@ out_err:
        return ret;
 }
 
-static int smb_direct_listen_handler(struct rdma_cm_id *cm_id,
-                                    struct rdma_cm_event *event)
+static int smb_direct_listener_kthread_fn(void *p)
 {
-       switch (event->event) {
-       case RDMA_CM_EVENT_CONNECT_REQUEST: {
-               int ret = smb_direct_handle_connect_request(cm_id, event);
+       struct smb_direct_listener *listener = (struct smb_direct_listener *)p;
+       struct smbdirect_socket *client_sc = NULL;
 
-               if (ret) {
-                       pr_err("Can't create transport: %d\n", ret);
-                       return ret;
-               }
+       while (!kthread_should_stop()) {
+               struct proto_accept_arg arg = { .err = -EINVAL, };
+               long timeo = MAX_SCHEDULE_TIMEOUT;
 
-               ksmbd_debug(RDMA, "Received connection request. cm_id=%p\n",
-                           cm_id);
-               break;
-       }
-       default:
-               pr_err("Unexpected listen event. cm_id=%p, event=%s (%d)\n",
-                      cm_id, rdma_event_msg(event->event), event->event);
-               break;
+               if (!listener->socket)
+                       break;
+               client_sc = smbdirect_socket_accept(listener->socket, timeo, &arg);
+               if (!client_sc && arg.err == -EINVAL)
+                       break;
+               if (!client_sc)
+                       continue;
+
+               ksmbd_debug(CONN, "connect success: accepted new connection\n");
+               smb_direct_new_connection(listener, client_sc);
        }
+
+       ksmbd_debug(CONN, "releasing socket\n");
        return 0;
 }
 
+static void smb_direct_listener_destroy(struct smb_direct_listener *listener)
+{
+       int ret;
+
+       if (listener->socket)
+               smbdirect_socket_shutdown(listener->socket);
+
+       if (listener->thread) {
+               ret = kthread_stop(listener->thread);
+               if (ret)
+                       pr_err("failed to stop forker thread\n");
+               listener->thread = NULL;
+       }
+
+       if (listener->socket) {
+               smbdirect_socket_release(listener->socket);
+               listener->socket = NULL;
+       }
+
+       listener->port = 0;
+}
+
 static int smb_direct_listen(struct smb_direct_listener *listener,
                             int port)
 {
-       int ret;
-       struct rdma_cm_id *cm_id;
-       u8 node_type = RDMA_NODE_UNSPECIFIED;
+       struct net *net = current->nsproxy->net_ns;
+       struct task_struct *kthread;
        struct sockaddr_in sin = {
                .sin_family             = AF_INET,
                .sin_addr.s_addr        = htonl(INADDR_ANY),
                .sin_port               = htons(port),
        };
+       struct smbdirect_socket_parameters init_params = {};
+       struct smbdirect_socket_parameters *sp;
+       struct smbdirect_socket *sc;
+       u64 port_flags = 0;
+       int ret;
 
        switch (port) {
        case SMB_DIRECT_PORT_IWARP:
@@ -446,7 +406,7 @@ static int smb_direct_listen(struct smb_direct_listener *listener,
                 * only allow iWarp devices
                 * for port 5445.
                 */
-               node_type = RDMA_NODE_RNIC;
+               port_flags |= SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW;
                break;
        case SMB_DIRECT_PORT_INFINIBAND:
                /*
@@ -455,47 +415,90 @@ static int smb_direct_listen(struct smb_direct_listener *listener,
                 *
                 * (Basically don't allow iWarp devices)
                 */
-               node_type = RDMA_NODE_IB_CA;
+               port_flags |= SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB;
                break;
        default:
                pr_err("unsupported smbdirect port=%d!\n", port);
                return -ENODEV;
        }
 
-       cm_id = rdma_create_id(&init_net, smb_direct_listen_handler,
-                              listener, RDMA_PS_TCP, IB_QPT_RC);
-       if (IS_ERR(cm_id)) {
-               pr_err("Can't create cm id: %ld\n", PTR_ERR(cm_id));
-               return PTR_ERR(cm_id);
+       ret = smbdirect_socket_create_kern(net, &sc);
+       if (ret) {
+               pr_err("smbdirect_socket_create_kern() failed: %d %1pe\n",
+                      ret, ERR_PTR(ret));
+               return ret;
        }
 
-       ret = rdma_restrict_node_type(cm_id, node_type);
+       /*
+        * Create the initial parameters
+        */
+       sp = &init_params;
+       sp->flags |= port_flags;
+       sp->negotiate_timeout_msec = SMB_DIRECT_NEGOTIATE_TIMEOUT * 1000;
+       sp->initiator_depth = SMB_DIRECT_CM_INITIATOR_DEPTH;
+       sp->responder_resources = 1;
+       sp->recv_credit_max = smb_direct_receive_credit_max;
+       sp->send_credit_target = smb_direct_send_credit_target;
+       sp->max_send_size = smb_direct_max_send_size;
+       sp->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size;
+       sp->max_recv_size = smb_direct_max_receive_size;
+       sp->max_read_write_size = smb_direct_max_read_write_size;
+       sp->keepalive_interval_msec = SMB_DIRECT_KEEPALIVE_SEND_INTERVAL * 1000;
+       sp->keepalive_timeout_msec = SMB_DIRECT_KEEPALIVE_RECV_TIMEOUT * 1000;
+
+       smbdirect_socket_set_logging(sc, NULL,
+                                    smb_direct_logging_needed,
+                                    smb_direct_logging_vaprintf);
+       ret = smbdirect_socket_set_initial_parameters(sc, sp);
        if (ret) {
-               pr_err("rdma_restrict_node_type(%u) failed %d\n",
-                      node_type, ret);
+               pr_err("Failed smbdirect_socket_set_initial_parameters(): %d %1pe\n",
+                      ret, ERR_PTR(ret));
+               goto err;
+       }
+       ret = smbdirect_socket_set_kernel_settings(sc, IB_POLL_WORKQUEUE, KSMBD_DEFAULT_GFP);
+       if (ret) {
+               pr_err("Failed smbdirect_socket_set_kernel_settings(): %d %1pe\n",
+                      ret, ERR_PTR(ret));
+               goto err;
+       }
+       ret = smbdirect_socket_set_custom_workqueue(sc, smb_direct_wq);
+       if (ret) {
+               pr_err("Failed smbdirect_socket_set_custom_workqueue(): %d %1pe\n",
+                      ret, ERR_PTR(ret));
                goto err;
        }
 
-       ret = rdma_bind_addr(cm_id, (struct sockaddr *)&sin);
+       ret = smbdirect_socket_bind(sc, (struct sockaddr *)&sin);
        if (ret) {
-               pr_err("Can't bind: %d\n", ret);
+               pr_err("smbdirect_socket_bind() failed: %d %1pe\n",
+                      ret, ERR_PTR(ret));
                goto err;
        }
 
-       ret = rdma_listen(cm_id, 10);
+       ret = smbdirect_socket_listen(sc, 10);
        if (ret) {
-               pr_err("Can't listen: %d\n", ret);
+               pr_err("Port[%d] smbdirect_socket_listen() failed: %d %1pe\n",
+                      port, ret, ERR_PTR(ret));
                goto err;
        }
 
        listener->port = port;
-       listener->cm_id = cm_id;
+       listener->socket = sc;
+
+       kthread = kthread_run(smb_direct_listener_kthread_fn,
+                             listener,
+                             "ksmbd-smbdirect-listener-%u", port);
+       if (IS_ERR(kthread)) {
+               ret = PTR_ERR(kthread);
+               pr_err("Can't start ksmbd listen kthread: %d %1pe\n",
+                      ret, ERR_PTR(ret));
+               goto err;
+       }
 
+       listener->thread = kthread;
        return 0;
 err:
-       listener->port = 0;
-       listener->cm_id = NULL;
-       rdma_destroy_id(cm_id);
+       smb_direct_listener_destroy(listener);
        return ret;
 }
 
@@ -546,7 +549,7 @@ int ksmbd_rdma_init(void)
        int ret;
 
        smb_direct_ib_listener = smb_direct_iw_listener = (struct smb_direct_listener) {
-               .cm_id = NULL,
+               .socket = NULL,
        };
 
        ret = ib_register_client(&smb_direct_ib_client);
@@ -575,8 +578,8 @@ int ksmbd_rdma_init(void)
                goto err;
        }
 
-       ksmbd_debug(RDMA, "InfiniBand/RoCEv1/RoCEv2 RDMA listener. cm_id=%p\n",
-                   smb_direct_ib_listener.cm_id);
+       ksmbd_debug(RDMA, "InfiniBand/RoCEv1/RoCEv2 RDMA listener. socket=%p\n",
+                   smb_direct_ib_listener.socket);
 
        ret = smb_direct_listen(&smb_direct_iw_listener,
                                SMB_DIRECT_PORT_IWARP);
@@ -585,8 +588,8 @@ int ksmbd_rdma_init(void)
                goto err;
        }
 
-       ksmbd_debug(RDMA, "iWarp RDMA listener. cm_id=%p\n",
-                   smb_direct_iw_listener.cm_id);
+       ksmbd_debug(RDMA, "iWarp RDMA listener. socket=%p\n",
+                   smb_direct_iw_listener.socket);
 
        return 0;
 err:
@@ -597,19 +600,13 @@ err:
 
 void ksmbd_rdma_stop_listening(void)
 {
-       if (!smb_direct_ib_listener.cm_id && !smb_direct_iw_listener.cm_id)
+       if (!smb_direct_ib_listener.socket && !smb_direct_iw_listener.socket)
                return;
 
        ib_unregister_client(&smb_direct_ib_client);
 
-       if (smb_direct_ib_listener.cm_id)
-               rdma_destroy_id(smb_direct_ib_listener.cm_id);
-       if (smb_direct_iw_listener.cm_id)
-               rdma_destroy_id(smb_direct_iw_listener.cm_id);
-
-       smb_direct_ib_listener = smb_direct_iw_listener = (struct smb_direct_listener) {
-               .cm_id = NULL,
-       };
+       smb_direct_listener_destroy(&smb_direct_ib_listener);
+       smb_direct_listener_destroy(&smb_direct_iw_listener);
 }
 
 void ksmbd_rdma_destroy(void)
@@ -685,7 +682,6 @@ bool ksmbd_rdma_capable_netdev(struct net_device *netdev)
 }
 
 static const struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = {
-       .prepare        = smb_direct_prepare,
        .disconnect     = smb_direct_disconnect,
        .shutdown       = smb_direct_shutdown,
        .writev         = smb_direct_writev,