]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
smb: client: add mid_counter_lock to protect the mid counter counter
authorWang Zhaolong <wangzhaolong@huaweicloud.com>
Mon, 4 Aug 2025 13:40:04 +0000 (21:40 +0800)
committerSteve French <stfrench@microsoft.com>
Tue, 5 Aug 2025 16:29:00 +0000 (11:29 -0500)
This is step 2/4 of a patch series to fix mid_q_entry memory leaks
caused by race conditions in callback execution.

Add a dedicated mid_counter_lock to protect current_mid counter,
separating it from mid_queue_lock which protects pending_mid_q
operations. This reduces lock contention and prepares for finer-
grained locking in subsequent patches.

Changes:
- Add TCP_Server_Info->mid_counter_lock spinlock
- Rename CurrentMid to current_mid for consistency
- Use mid_counter_lock to protect current_mid access
- Update locking documentation in cifsglob.h

This separation allows mid allocation to proceed without blocking
queue operations, improving performance under heavy load.

Signed-off-by: Wang Zhaolong <wangzhaolong@huaweicloud.com>
Acked-by: Enzo Matsumiya <ematsumiya@suse.de>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/smb/client/cifsglob.h
fs/smb/client/connect.c
fs/smb/client/smb1ops.c
fs/smb/client/smb2ops.c
fs/smb/client/transport.c

index 2dd1ef27425095b4b3eecff61621359d2c38a46a..cfba226f3396d15a1a27cfe575dc1e37dadf523a 100644 (file)
@@ -733,6 +733,7 @@ struct TCP_Server_Info {
        wait_queue_head_t response_q;
        wait_queue_head_t request_q; /* if more than maxmpx to srvr must block*/
        spinlock_t mid_queue_lock;  /* protect mid queue */
+       spinlock_t mid_counter_lock;
        struct list_head pending_mid_q;
        bool noblocksnd;                /* use blocking sendmsg */
        bool noautotune;                /* do not autotune send buf sizes */
@@ -770,7 +771,7 @@ struct TCP_Server_Info {
        /* SMB_COM_WRITE_RAW or SMB_COM_READ_RAW. */
        unsigned int capabilities; /* selective disabling of caps by smb sess */
        int timeAdj;  /* Adjust for difference in server time zone in sec */
-       __u64 CurrentMid;         /* multiplex id - rotating counter, protected by GlobalMid_Lock */
+       __u64 current_mid;      /* multiplex id - rotating counter, protected by mid_counter_lock */
        char cryptkey[CIFS_CRYPTO_KEY_SIZE]; /* used by ntlm, ntlmv2 etc */
        /* 16th byte of RFC1001 workstation name is always null */
        char workstation_RFC1001_name[RFC1001_NAME_LEN_WITH_NULL];
@@ -2008,8 +2009,8 @@ require use of the stronger protocol */
  *                             GlobalTotalActiveXid
  * TCP_Server_Info->srv_lock   (anything in struct not protected by another lock and can change)
  * TCP_Server_Info->mid_queue_lock     TCP_Server_Info->pending_mid_q  cifs_get_tcp_session
- *                             ->CurrentMid
  *                             (any changes in mid_q_entry fields)
+ * TCP_Server_Info->mid_counter_lock    TCP_Server_Info->current_mid    cifs_get_tcp_session
  * TCP_Server_Info->req_lock   TCP_Server_Info->in_flight      cifs_get_tcp_session
  *                             ->credits
  *                             ->echo_credits
index e4b577ca48d5956ffa8c7acb6659cd69b42adc39..74ad5881ee457104a7f9a7bedb97db16b0f7a683 100644 (file)
@@ -358,7 +358,7 @@ static bool cifs_tcp_ses_needs_reconnect(struct TCP_Server_Info *server, int num
        }
 
        cifs_dbg(FYI, "Mark tcp session as need reconnect\n");
-       trace_smb3_reconnect(server->CurrentMid, server->conn_id,
+       trace_smb3_reconnect(server->current_mid, server->conn_id,
                             server->hostname);
        server->tcpStatus = CifsNeedReconnect;
 
@@ -1242,7 +1242,7 @@ smb2_add_credits_from_hdr(char *buffer, struct TCP_Server_Info *server)
                spin_unlock(&server->req_lock);
                wake_up(&server->request_q);
 
-               trace_smb3_hdr_credits(server->CurrentMid,
+               trace_smb3_hdr_credits(server->current_mid,
                                server->conn_id, server->hostname, scredits,
                                le16_to_cpu(shdr->CreditRequest), in_flight);
                cifs_server_dbg(FYI, "%s: added %u credits total=%d\n",
@@ -1823,6 +1823,7 @@ cifs_get_tcp_session(struct smb3_fs_context *ctx,
        spin_lock_init(&tcp_ses->req_lock);
        spin_lock_init(&tcp_ses->srv_lock);
        spin_lock_init(&tcp_ses->mid_queue_lock);
+       spin_lock_init(&tcp_ses->mid_counter_lock);
        INIT_LIST_HEAD(&tcp_ses->tcp_ses_list);
        INIT_LIST_HEAD(&tcp_ses->smb_ses_list);
        INIT_DELAYED_WORK(&tcp_ses->echo, cifs_echo_request);
index e16566d3c319318be8ef5c8ccc9c6177a4c3949c..893a1ea8c000acb5617a8eb9b78444f0bdb8430e 100644 (file)
@@ -169,10 +169,9 @@ cifs_get_next_mid(struct TCP_Server_Info *server)
        __u16 last_mid, cur_mid;
        bool collision, reconnect = false;
 
-       spin_lock(&server->mid_queue_lock);
-
+       spin_lock(&server->mid_counter_lock);
        /* mid is 16 bit only for CIFS/SMB */
-       cur_mid = (__u16)((server->CurrentMid) & 0xffff);
+       cur_mid = (__u16)((server->current_mid) & 0xffff);
        /* we do not want to loop forever */
        last_mid = cur_mid;
        cur_mid++;
@@ -198,6 +197,7 @@ cifs_get_next_mid(struct TCP_Server_Info *server)
                        cur_mid++;
 
                num_mids = 0;
+               spin_lock(&server->mid_queue_lock);
                list_for_each_entry(mid_entry, &server->pending_mid_q, qhead) {
                        ++num_mids;
                        if (mid_entry->mid == cur_mid &&
@@ -207,6 +207,7 @@ cifs_get_next_mid(struct TCP_Server_Info *server)
                                break;
                        }
                }
+               spin_unlock(&server->mid_queue_lock);
 
                /*
                 * if we have more than 32k mids in the list, then something
@@ -223,12 +224,12 @@ cifs_get_next_mid(struct TCP_Server_Info *server)
 
                if (!collision) {
                        mid = (__u64)cur_mid;
-                       server->CurrentMid = mid;
+                       server->current_mid = mid;
                        break;
                }
                cur_mid++;
        }
-       spin_unlock(&server->mid_queue_lock);
+       spin_unlock(&server->mid_counter_lock);
 
        if (reconnect) {
                cifs_signal_cifsd_for_reconnect(server, false);
index 7935f9b433ac2c198c8e0f6112ccec109ca8a0e9..ebaeb2993569f4ec63f8e88ce4c0177a3a72bd4a 100644 (file)
@@ -91,7 +91,7 @@ smb2_add_credits(struct TCP_Server_Info *server,
        if (*val > 65000) {
                *val = 65000; /* Don't get near 64K credits, avoid srv bugs */
                pr_warn_once("server overflowed SMB3 credits\n");
-               trace_smb3_overflow_credits(server->CurrentMid,
+               trace_smb3_overflow_credits(server->current_mid,
                                            server->conn_id, server->hostname, *val,
                                            add, server->in_flight);
        }
@@ -136,7 +136,7 @@ smb2_add_credits(struct TCP_Server_Info *server,
        wake_up(&server->request_q);
 
        if (reconnect_detected) {
-               trace_smb3_reconnect_detected(server->CurrentMid,
+               trace_smb3_reconnect_detected(server->current_mid,
                        server->conn_id, server->hostname, scredits, add, in_flight);
 
                cifs_dbg(FYI, "trying to put %d credits from the old server instance %d\n",
@@ -144,7 +144,7 @@ smb2_add_credits(struct TCP_Server_Info *server,
        }
 
        if (reconnect_with_invalid_credits) {
-               trace_smb3_reconnect_with_invalid_credits(server->CurrentMid,
+               trace_smb3_reconnect_with_invalid_credits(server->current_mid,
                        server->conn_id, server->hostname, scredits, add, in_flight);
                cifs_dbg(FYI, "Negotiate operation when server credits is non-zero. Optype: %d, server credits: %d, credits added: %d\n",
                         optype, scredits, add);
@@ -176,7 +176,7 @@ smb2_add_credits(struct TCP_Server_Info *server,
                break;
        }
 
-       trace_smb3_add_credits(server->CurrentMid,
+       trace_smb3_add_credits(server->current_mid,
                        server->conn_id, server->hostname, scredits, add, in_flight);
        cifs_dbg(FYI, "%s: added %u credits total=%d\n", __func__, add, scredits);
 }
@@ -203,7 +203,7 @@ smb2_set_credits(struct TCP_Server_Info *server, const int val)
        in_flight = server->in_flight;
        spin_unlock(&server->req_lock);
 
-       trace_smb3_set_credits(server->CurrentMid,
+       trace_smb3_set_credits(server->current_mid,
                        server->conn_id, server->hostname, scredits, val, in_flight);
        cifs_dbg(FYI, "%s: set %u credits\n", __func__, val);
 
@@ -288,7 +288,7 @@ smb2_wait_mtu_credits(struct TCP_Server_Info *server, size_t size,
        in_flight = server->in_flight;
        spin_unlock(&server->req_lock);
 
-       trace_smb3_wait_credits(server->CurrentMid,
+       trace_smb3_wait_credits(server->current_mid,
                        server->conn_id, server->hostname, scredits, -(credits->value), in_flight);
        cifs_dbg(FYI, "%s: removed %u credits total=%d\n",
                        __func__, credits->value, scredits);
@@ -316,7 +316,7 @@ smb2_adjust_credits(struct TCP_Server_Info *server,
                                      server->credits, server->in_flight,
                                      new_val - credits->value,
                                      cifs_trace_rw_credits_no_adjust_up);
-               trace_smb3_too_many_credits(server->CurrentMid,
+               trace_smb3_too_many_credits(server->current_mid,
                                server->conn_id, server->hostname, 0, credits->value - new_val, 0);
                cifs_server_dbg(VFS, "R=%x[%x] request has less credits (%d) than required (%d)",
                                subreq->rreq->debug_id, subreq->subreq.debug_index,
@@ -338,7 +338,7 @@ smb2_adjust_credits(struct TCP_Server_Info *server,
                                      server->credits, server->in_flight,
                                      new_val - credits->value,
                                      cifs_trace_rw_credits_old_session);
-               trace_smb3_reconnect_detected(server->CurrentMid,
+               trace_smb3_reconnect_detected(server->current_mid,
                        server->conn_id, server->hostname, scredits,
                        credits->value - new_val, in_flight);
                cifs_server_dbg(VFS, "R=%x[%x] trying to return %d credits to old session\n",
@@ -358,7 +358,7 @@ smb2_adjust_credits(struct TCP_Server_Info *server,
        spin_unlock(&server->req_lock);
        wake_up(&server->request_q);
 
-       trace_smb3_adj_credits(server->CurrentMid,
+       trace_smb3_adj_credits(server->current_mid,
                        server->conn_id, server->hostname, scredits,
                        credits->value - new_val, in_flight);
        cifs_dbg(FYI, "%s: adjust added %u credits total=%d\n",
@@ -374,19 +374,19 @@ smb2_get_next_mid(struct TCP_Server_Info *server)
 {
        __u64 mid;
        /* for SMB2 we need the current value */
-       spin_lock(&server->mid_queue_lock);
-       mid = server->CurrentMid++;
-       spin_unlock(&server->mid_queue_lock);
+       spin_lock(&server->mid_counter_lock);
+       mid = server->current_mid++;
+       spin_unlock(&server->mid_counter_lock);
        return mid;
 }
 
 static void
 smb2_revert_current_mid(struct TCP_Server_Info *server, const unsigned int val)
 {
-       spin_lock(&server->mid_queue_lock);
-       if (server->CurrentMid >= val)
-               server->CurrentMid -= val;
-       spin_unlock(&server->mid_queue_lock);
+       spin_lock(&server->mid_counter_lock);
+       if (server->current_mid >= val)
+               server->current_mid -= val;
+       spin_unlock(&server->mid_counter_lock);
 }
 
 static struct mid_q_entry *
@@ -460,9 +460,9 @@ smb2_negotiate(const unsigned int xid,
 {
        int rc;
 
-       spin_lock(&server->mid_queue_lock);
-       server->CurrentMid = 0;
-       spin_unlock(&server->mid_queue_lock);
+       spin_lock(&server->mid_counter_lock);
+       server->current_mid = 0;
+       spin_unlock(&server->mid_counter_lock);
        rc = SMB2_negotiate(xid, ses, server);
        return rc;
 }
@@ -2498,7 +2498,7 @@ smb2_is_status_pending(char *buf, struct TCP_Server_Info *server)
                spin_unlock(&server->req_lock);
                wake_up(&server->request_q);
 
-               trace_smb3_pend_credits(server->CurrentMid,
+               trace_smb3_pend_credits(server->current_mid,
                                server->conn_id, server->hostname, scredits,
                                le16_to_cpu(shdr->CreditRequest), in_flight);
                cifs_dbg(FYI, "%s: status pending add %u credits total=%d\n",
index 12dc927aa4a2011910cec29ed8a2d2d310293f4f..8037accc3987b6f749677c4c29e1bd1a7ac67ba2 100644 (file)
@@ -397,7 +397,7 @@ unmask:
                 * socket so the server throws away the partial SMB
                 */
                cifs_signal_cifsd_for_reconnect(server, false);
-               trace_smb3_partial_send_reconnect(server->CurrentMid,
+               trace_smb3_partial_send_reconnect(server->current_mid,
                                                  server->conn_id, server->hostname);
        }
 smbd_done:
@@ -509,7 +509,7 @@ wait_for_free_credits(struct TCP_Server_Info *server, const int num_credits,
                in_flight = server->in_flight;
                spin_unlock(&server->req_lock);
 
-               trace_smb3_nblk_credits(server->CurrentMid,
+               trace_smb3_nblk_credits(server->current_mid,
                                server->conn_id, server->hostname, scredits, -1, in_flight);
                cifs_dbg(FYI, "%s: remove %u credits total=%d\n",
                                __func__, 1, scredits);
@@ -542,7 +542,7 @@ wait_for_free_credits(struct TCP_Server_Info *server, const int num_credits,
                                in_flight = server->in_flight;
                                spin_unlock(&server->req_lock);
 
-                               trace_smb3_credit_timeout(server->CurrentMid,
+                               trace_smb3_credit_timeout(server->current_mid,
                                                server->conn_id, server->hostname, scredits,
                                                num_credits, in_flight);
                                cifs_server_dbg(VFS, "wait timed out after %d ms\n",
@@ -585,7 +585,7 @@ wait_for_free_credits(struct TCP_Server_Info *server, const int num_credits,
                                        spin_unlock(&server->req_lock);
 
                                        trace_smb3_credit_timeout(
-                                                       server->CurrentMid,
+                                                       server->current_mid,
                                                        server->conn_id, server->hostname,
                                                        scredits, num_credits, in_flight);
                                        cifs_server_dbg(VFS, "wait timed out after %d ms\n",
@@ -615,7 +615,7 @@ wait_for_free_credits(struct TCP_Server_Info *server, const int num_credits,
                        in_flight = server->in_flight;
                        spin_unlock(&server->req_lock);
 
-                       trace_smb3_waitff_credits(server->CurrentMid,
+                       trace_smb3_waitff_credits(server->current_mid,
                                        server->conn_id, server->hostname, scredits,
                                        -(num_credits), in_flight);
                        cifs_dbg(FYI, "%s: remove %u credits total=%d\n",
@@ -666,7 +666,7 @@ wait_for_compound_request(struct TCP_Server_Info *server, int num,
                 */
                if (server->in_flight == 0) {
                        spin_unlock(&server->req_lock);
-                       trace_smb3_insufficient_credits(server->CurrentMid,
+                       trace_smb3_insufficient_credits(server->current_mid,
                                        server->conn_id, server->hostname, scredits,
                                        num, in_flight);
                        cifs_dbg(FYI, "%s: %d requests in flight, needed %d total=%d\n",