From: Stefan Metzmacher Date: Fri, 14 Nov 2025 14:41:02 +0000 (+0100) Subject: smb: server: make use of smbdirect_socket_{listen,accept}() X-Git-Tag: v7.1-rc1~128^2~15 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2eff5e51f97663ad2371115260884396718b5e92;p=thirdparty%2Fkernel%2Flinux.git smb: server: make use of smbdirect_socket_{listen,accept}() We no longer need the custom rdma listener. The code logic is very similar to transport_tcp.c now using a kernel thread that loops over smbdirect_socket_accept(). This is the first step in the direction of using IPPROTO_SMBDIRECT sockets in future. Cc: Namjae Jeon Cc: Steve French Cc: Tom Talpey Cc: linux-cifs@vger.kernel.org Cc: samba-technical@lists.samba.org Signed-off-by: Stefan Metzmacher Acked-by: Namjae Jeon Signed-off-by: Steve French --- diff --git a/fs/smb/server/transport_rdma.c b/fs/smb/server/transport_rdma.c index e58d7e89da0ef..7171bde9d0788 100644 --- a/fs/smb/server/transport_rdma.c +++ b/fs/smb/server/transport_rdma.c @@ -89,7 +89,10 @@ struct smb_direct_device { static struct smb_direct_listener { int port; - struct rdma_cm_id *cm_id; + + struct task_struct *thread; + + struct smbdirect_socket *socket; } smb_direct_ib_listener, smb_direct_iw_listener; static struct workqueue_struct *smb_direct_wq; @@ -185,49 +188,15 @@ unsigned int get_smbd_max_read_write_size(struct ksmbd_transport *kt) return sp->max_read_write_size; } -static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id) +static struct smb_direct_transport *alloc_transport(struct smbdirect_socket *sc) { struct smb_direct_transport *t; - struct smbdirect_socket *sc; - struct smbdirect_socket_parameters init_params = {}; - struct smbdirect_socket_parameters *sp; struct ksmbd_conn *conn; - int ret; - - /* - * Create the initial parameters - */ - sp = &init_params; - sp->negotiate_timeout_msec = SMB_DIRECT_NEGOTIATE_TIMEOUT * 1000; - sp->initiator_depth = SMB_DIRECT_CM_INITIATOR_DEPTH; - sp->responder_resources = 1; - sp->recv_credit_max = smb_direct_receive_credit_max; - sp->send_credit_target = smb_direct_send_credit_target; - sp->max_send_size = smb_direct_max_send_size; - sp->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size; - sp->max_recv_size = smb_direct_max_receive_size; - sp->max_read_write_size = smb_direct_max_read_write_size; - sp->keepalive_interval_msec = SMB_DIRECT_KEEPALIVE_SEND_INTERVAL * 1000; - sp->keepalive_timeout_msec = SMB_DIRECT_KEEPALIVE_RECV_TIMEOUT * 1000; t = kzalloc_obj(*t, KSMBD_DEFAULT_GFP); if (!t) return NULL; - ret = smbdirect_socket_create_accepting(cm_id, &sc); - if (ret) - goto socket_create_failed; - smbdirect_socket_set_logging(sc, NULL, - smb_direct_logging_needed, - smb_direct_logging_vaprintf); - ret = smbdirect_socket_set_initial_parameters(sc, sp); - if (ret) - goto set_params_failed; - ret = smbdirect_socket_set_kernel_settings(sc, IB_POLL_WORKQUEUE, KSMBD_DEFAULT_GFP); - if (ret) - goto set_settings_failed; - ret = smbdirect_socket_set_custom_workqueue(sc, smb_direct_wq); - if (ret) - goto set_workqueue_failed; + t->socket = sc; conn = ksmbd_conn_alloc(); if (!conn) @@ -241,15 +210,9 @@ static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id) KSMBD_TRANS(t)->conn = conn; KSMBD_TRANS(t)->ops = &ksmbd_smb_direct_transport_ops; - t->socket = sc; return t; conn_alloc_failed: -set_workqueue_failed: -set_settings_failed: -set_params_failed: - smbdirect_socket_release(sc); -socket_create_failed: kfree(t); return NULL; } @@ -346,48 +309,18 @@ static void smb_direct_shutdown(struct ksmbd_transport *t) smbdirect_socket_shutdown(sc); } -static int smb_direct_prepare(struct ksmbd_transport *t) -{ - struct smb_direct_transport *st = SMBD_TRANS(t); - struct smbdirect_socket *sc = st->socket; - int ret; - - ksmbd_debug(RDMA, "SMB_DIRECT Waiting for connection\n"); - ret = smbdirect_connection_wait_for_connected(sc); - if (ret) { - ksmbd_debug(RDMA, "SMB_DIRECT connection failed %d => %1pe\n", - ret, ERR_PTR(ret)); - return ret; - } - - ksmbd_debug(RDMA, "SMB_DIRECT connection ready\n"); - return 0; -} - -static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id, - struct rdma_cm_event *event) +static int smb_direct_new_connection(struct smb_direct_listener *listener, + struct smbdirect_socket *client_sc) { - struct smb_direct_listener *listener = new_cm_id->context; struct smb_direct_transport *t; - struct smbdirect_socket *sc; struct task_struct *handler; int ret; - if (!smbdirect_frwr_is_supported(&new_cm_id->device->attrs)) { - ksmbd_debug(RDMA, - "Fast Registration Work Requests is not supported. device capabilities=%llx\n", - new_cm_id->device->attrs.device_cap_flags); - return -EPROTONOSUPPORT; - } - - t = alloc_transport(new_cm_id); - if (!t) + t = alloc_transport(client_sc); + if (!t) { + smbdirect_socket_release(client_sc); return -ENOMEM; - sc = t->socket; - - ret = smbdirect_accept_connect_request(sc, &event->param.conn); - if (ret) - goto out_err; + } handler = kthread_run(ksmbd_conn_handler_loop, KSMBD_TRANS(t)->conn, "ksmbd:r%u", @@ -404,41 +337,68 @@ out_err: return ret; } -static int smb_direct_listen_handler(struct rdma_cm_id *cm_id, - struct rdma_cm_event *event) +static int smb_direct_listener_kthread_fn(void *p) { - switch (event->event) { - case RDMA_CM_EVENT_CONNECT_REQUEST: { - int ret = smb_direct_handle_connect_request(cm_id, event); + struct smb_direct_listener *listener = (struct smb_direct_listener *)p; + struct smbdirect_socket *client_sc = NULL; - if (ret) { - pr_err("Can't create transport: %d\n", ret); - return ret; - } + while (!kthread_should_stop()) { + struct proto_accept_arg arg = { .err = -EINVAL, }; + long timeo = MAX_SCHEDULE_TIMEOUT; - ksmbd_debug(RDMA, "Received connection request. cm_id=%p\n", - cm_id); - break; - } - default: - pr_err("Unexpected listen event. cm_id=%p, event=%s (%d)\n", - cm_id, rdma_event_msg(event->event), event->event); - break; + if (!listener->socket) + break; + client_sc = smbdirect_socket_accept(listener->socket, timeo, &arg); + if (!client_sc && arg.err == -EINVAL) + break; + if (!client_sc) + continue; + + ksmbd_debug(CONN, "connect success: accepted new connection\n"); + smb_direct_new_connection(listener, client_sc); } + + ksmbd_debug(CONN, "releasing socket\n"); return 0; } +static void smb_direct_listener_destroy(struct smb_direct_listener *listener) +{ + int ret; + + if (listener->socket) + smbdirect_socket_shutdown(listener->socket); + + if (listener->thread) { + ret = kthread_stop(listener->thread); + if (ret) + pr_err("failed to stop forker thread\n"); + listener->thread = NULL; + } + + if (listener->socket) { + smbdirect_socket_release(listener->socket); + listener->socket = NULL; + } + + listener->port = 0; +} + static int smb_direct_listen(struct smb_direct_listener *listener, int port) { - int ret; - struct rdma_cm_id *cm_id; - u8 node_type = RDMA_NODE_UNSPECIFIED; + struct net *net = current->nsproxy->net_ns; + struct task_struct *kthread; struct sockaddr_in sin = { .sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_ANY), .sin_port = htons(port), }; + struct smbdirect_socket_parameters init_params = {}; + struct smbdirect_socket_parameters *sp; + struct smbdirect_socket *sc; + u64 port_flags = 0; + int ret; switch (port) { case SMB_DIRECT_PORT_IWARP: @@ -446,7 +406,7 @@ static int smb_direct_listen(struct smb_direct_listener *listener, * only allow iWarp devices * for port 5445. */ - node_type = RDMA_NODE_RNIC; + port_flags |= SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW; break; case SMB_DIRECT_PORT_INFINIBAND: /* @@ -455,47 +415,90 @@ static int smb_direct_listen(struct smb_direct_listener *listener, * * (Basically don't allow iWarp devices) */ - node_type = RDMA_NODE_IB_CA; + port_flags |= SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB; break; default: pr_err("unsupported smbdirect port=%d!\n", port); return -ENODEV; } - cm_id = rdma_create_id(&init_net, smb_direct_listen_handler, - listener, RDMA_PS_TCP, IB_QPT_RC); - if (IS_ERR(cm_id)) { - pr_err("Can't create cm id: %ld\n", PTR_ERR(cm_id)); - return PTR_ERR(cm_id); + ret = smbdirect_socket_create_kern(net, &sc); + if (ret) { + pr_err("smbdirect_socket_create_kern() failed: %d %1pe\n", + ret, ERR_PTR(ret)); + return ret; } - ret = rdma_restrict_node_type(cm_id, node_type); + /* + * Create the initial parameters + */ + sp = &init_params; + sp->flags |= port_flags; + sp->negotiate_timeout_msec = SMB_DIRECT_NEGOTIATE_TIMEOUT * 1000; + sp->initiator_depth = SMB_DIRECT_CM_INITIATOR_DEPTH; + sp->responder_resources = 1; + sp->recv_credit_max = smb_direct_receive_credit_max; + sp->send_credit_target = smb_direct_send_credit_target; + sp->max_send_size = smb_direct_max_send_size; + sp->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size; + sp->max_recv_size = smb_direct_max_receive_size; + sp->max_read_write_size = smb_direct_max_read_write_size; + sp->keepalive_interval_msec = SMB_DIRECT_KEEPALIVE_SEND_INTERVAL * 1000; + sp->keepalive_timeout_msec = SMB_DIRECT_KEEPALIVE_RECV_TIMEOUT * 1000; + + smbdirect_socket_set_logging(sc, NULL, + smb_direct_logging_needed, + smb_direct_logging_vaprintf); + ret = smbdirect_socket_set_initial_parameters(sc, sp); if (ret) { - pr_err("rdma_restrict_node_type(%u) failed %d\n", - node_type, ret); + pr_err("Failed smbdirect_socket_set_initial_parameters(): %d %1pe\n", + ret, ERR_PTR(ret)); + goto err; + } + ret = smbdirect_socket_set_kernel_settings(sc, IB_POLL_WORKQUEUE, KSMBD_DEFAULT_GFP); + if (ret) { + pr_err("Failed smbdirect_socket_set_kernel_settings(): %d %1pe\n", + ret, ERR_PTR(ret)); + goto err; + } + ret = smbdirect_socket_set_custom_workqueue(sc, smb_direct_wq); + if (ret) { + pr_err("Failed smbdirect_socket_set_custom_workqueue(): %d %1pe\n", + ret, ERR_PTR(ret)); goto err; } - ret = rdma_bind_addr(cm_id, (struct sockaddr *)&sin); + ret = smbdirect_socket_bind(sc, (struct sockaddr *)&sin); if (ret) { - pr_err("Can't bind: %d\n", ret); + pr_err("smbdirect_socket_bind() failed: %d %1pe\n", + ret, ERR_PTR(ret)); goto err; } - ret = rdma_listen(cm_id, 10); + ret = smbdirect_socket_listen(sc, 10); if (ret) { - pr_err("Can't listen: %d\n", ret); + pr_err("Port[%d] smbdirect_socket_listen() failed: %d %1pe\n", + port, ret, ERR_PTR(ret)); goto err; } listener->port = port; - listener->cm_id = cm_id; + listener->socket = sc; + + kthread = kthread_run(smb_direct_listener_kthread_fn, + listener, + "ksmbd-smbdirect-listener-%u", port); + if (IS_ERR(kthread)) { + ret = PTR_ERR(kthread); + pr_err("Can't start ksmbd listen kthread: %d %1pe\n", + ret, ERR_PTR(ret)); + goto err; + } + listener->thread = kthread; return 0; err: - listener->port = 0; - listener->cm_id = NULL; - rdma_destroy_id(cm_id); + smb_direct_listener_destroy(listener); return ret; } @@ -546,7 +549,7 @@ int ksmbd_rdma_init(void) int ret; smb_direct_ib_listener = smb_direct_iw_listener = (struct smb_direct_listener) { - .cm_id = NULL, + .socket = NULL, }; ret = ib_register_client(&smb_direct_ib_client); @@ -575,8 +578,8 @@ int ksmbd_rdma_init(void) goto err; } - ksmbd_debug(RDMA, "InfiniBand/RoCEv1/RoCEv2 RDMA listener. cm_id=%p\n", - smb_direct_ib_listener.cm_id); + ksmbd_debug(RDMA, "InfiniBand/RoCEv1/RoCEv2 RDMA listener. socket=%p\n", + smb_direct_ib_listener.socket); ret = smb_direct_listen(&smb_direct_iw_listener, SMB_DIRECT_PORT_IWARP); @@ -585,8 +588,8 @@ int ksmbd_rdma_init(void) goto err; } - ksmbd_debug(RDMA, "iWarp RDMA listener. cm_id=%p\n", - smb_direct_iw_listener.cm_id); + ksmbd_debug(RDMA, "iWarp RDMA listener. socket=%p\n", + smb_direct_iw_listener.socket); return 0; err: @@ -597,19 +600,13 @@ err: void ksmbd_rdma_stop_listening(void) { - if (!smb_direct_ib_listener.cm_id && !smb_direct_iw_listener.cm_id) + if (!smb_direct_ib_listener.socket && !smb_direct_iw_listener.socket) return; ib_unregister_client(&smb_direct_ib_client); - if (smb_direct_ib_listener.cm_id) - rdma_destroy_id(smb_direct_ib_listener.cm_id); - if (smb_direct_iw_listener.cm_id) - rdma_destroy_id(smb_direct_iw_listener.cm_id); - - smb_direct_ib_listener = smb_direct_iw_listener = (struct smb_direct_listener) { - .cm_id = NULL, - }; + smb_direct_listener_destroy(&smb_direct_ib_listener); + smb_direct_listener_destroy(&smb_direct_iw_listener); } void ksmbd_rdma_destroy(void) @@ -685,7 +682,6 @@ bool ksmbd_rdma_capable_netdev(struct net_device *netdev) } static const struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = { - .prepare = smb_direct_prepare, .disconnect = smb_direct_disconnect, .shutdown = smb_direct_shutdown, .writev = smb_direct_writev,