--- /dev/null
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * SMB2 compression support for ksmbd.
+ *
+ * Receive and send SMB 3.1.1 compression transforms using the common helpers.
+ *
+ * Copyright (C) 2026 Namjae Jeon <linkinjeon@kernel.org>
+ */
+#include <linux/slab.h>
+
+#include "compress.h"
+#include "smb_common.h"
+
+/**
+ * ksmbd_decompress_request() - replace a compressed request with its SMB2 PDU
+ * @conn: connection which owns the current RFC1002 request buffer
+ *
+ * Derive the uncompressed size from the transform variant, enforce ksmbd's
+ * normal message limits, and ask the common decoder to validate every payload.
+ * On success, replace conn->request_buf with a regular RFC1002-framed SMB2
+ * message so the rest of the request path needs no compression awareness.
+ *
+ * Return: 0 on success, otherwise a negative errno.
+ */
+int ksmbd_decompress_request(struct ksmbd_conn *conn)
+{
+ struct smb2_compression_hdr *hdr;
+ unsigned int pdu_size = get_rfc1002_len(conn->request_buf);
+ u32 orig_size, offset, out_size;
+ u32 max_allowed_pdu_size;
+ char *buf, *out;
+ int rc;
+
+ if (pdu_size < sizeof(struct smb2_compression_hdr))
+ return -EINVAL;
+
+ if (conn->dialect != SMB311_PROT_ID ||
+ conn->compress_algorithm == SMB3_COMPRESS_NONE)
+ return -EINVAL;
+
+ hdr = smb_get_msg(conn->request_buf);
+ if (hdr->ProtocolId != SMB2_COMPRESSION_TRANSFORM_ID)
+ return -EINVAL;
+
+ orig_size = le32_to_cpu(hdr->OriginalCompressedSegmentSize);
+ if (hdr->Flags == cpu_to_le16(SMB2_COMPRESSION_FLAG_CHAINED)) {
+ out_size = orig_size;
+ } else {
+ offset = le32_to_cpu(hdr->Offset);
+ if (offset > pdu_size - sizeof(*hdr) ||
+ check_add_overflow(orig_size, offset, &out_size))
+ return -EINVAL;
+ }
+
+ max_allowed_pdu_size = SMB3_MAX_MSGSIZE + conn->vals->max_write_size;
+ if (out_size > max_allowed_pdu_size ||
+ out_size > MAX_STREAM_PROT_LEN)
+ return -EINVAL;
+
+ out = kvmalloc(out_size + 4 + 1, KSMBD_DEFAULT_GFP);
+ if (!out)
+ return -ENOMEM;
+
+ buf = (char *)hdr;
+ *(__be32 *)out = cpu_to_be32(out_size);
+ rc = smb_compression_decompress(conn->compress_algorithm,
+ conn->compress_chained,
+ buf, pdu_size, out + 4, out_size);
+ if (rc) {
+ kvfree(out);
+ return rc;
+ }
+
+ kvfree(conn->request_buf);
+ conn->request_buf = out;
+ return 0;
+}
#include "ndr.h"
#include "stats.h"
#include "transport_tcp.h"
+#include "compress.h"
static void __wbuf(struct ksmbd_work *work, void **req, void **rsp)
{
pneg_ctxt->Ciphers[0] = cipher_type;
}
+static void build_compress_ctxt(struct smb2_compression_capabilities_context *pneg_ctxt,
+ __le16 compress_algorithm, bool compress_chained,
+ bool compress_pattern)
+{
+ /*
+ * Return only algorithms implemented by ksmbd. Pattern_V1 is advertised
+ * as a second ID when the client also enabled chained transforms.
+ */
+ pneg_ctxt->ContextType = SMB2_COMPRESSION_CAPABILITIES;
+ pneg_ctxt->DataLength = cpu_to_le16(compress_pattern ? 12 : 10);
+ pneg_ctxt->Reserved = cpu_to_le32(0);
+ pneg_ctxt->CompressionAlgorithmCount =
+ cpu_to_le16(compress_pattern ? 2 : 1);
+ pneg_ctxt->Padding = cpu_to_le16(0);
+ pneg_ctxt->Flags = compress_chained ?
+ SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED :
+ SMB2_COMPRESSION_CAPABILITIES_FLAG_NONE;
+ pneg_ctxt->CompressionAlgorithms[0] = compress_algorithm;
+ pneg_ctxt->CompressionAlgorithms[1] = compress_pattern ?
+ SMB3_COMPRESS_PATTERN : 0;
+ pneg_ctxt->CompressionAlgorithms[2] = 0;
+ pneg_ctxt->CompressionAlgorithms[3] = 0;
+}
+
static void build_sign_cap_ctxt(struct smb2_signing_capabilities *pneg_ctxt,
__le16 sign_algo)
{
ctxt_size += sizeof(struct smb2_encryption_neg_context) + 2;
}
- /* compression context not yet supported */
- WARN_ON(conn->compress_algorithm != SMB3_COMPRESS_NONE);
+ if (conn->compress_algorithm != SMB3_COMPRESS_NONE) {
+ ctxt_size = round_up(ctxt_size, 8);
+ ksmbd_debug(SMB,
+ "assemble SMB2_COMPRESSION_CAPABILITIES context\n");
+ build_compress_ctxt((struct smb2_compression_capabilities_context *)
+ (pneg_ctxt + ctxt_size),
+ conn->compress_algorithm,
+ conn->compress_chained,
+ conn->compress_pattern);
+ neg_ctxt_cnt++;
+ ctxt_size += sizeof(struct smb2_neg_context) +
+ (conn->compress_pattern ? 12 : 10);
+ }
if (conn->posix_ext_supported) {
ctxt_size = round_up(ctxt_size, 8);
conn->cipher_type;
}
-static void decode_compress_ctxt(struct ksmbd_conn *conn,
- struct smb2_compression_capabilities_context *pneg_ctxt)
+static __le32 decode_compress_ctxt(struct ksmbd_conn *conn,
+ struct smb2_compression_capabilities_context *pneg_ctxt,
+ int ctxt_len)
{
+ int alg_cnt, algs_size, i;
+
+ if (sizeof(struct smb2_neg_context) + 10 > ctxt_len) {
+ pr_err("Invalid SMB2_COMPRESSION_CAPABILITIES context length\n");
+ return STATUS_INVALID_PARAMETER;
+ }
+
conn->compress_algorithm = SMB3_COMPRESS_NONE;
+ conn->compress_chained = false;
+ conn->compress_pattern = false;
+
+ alg_cnt = le16_to_cpu(pneg_ctxt->CompressionAlgorithmCount);
+ if (!alg_cnt)
+ return STATUS_INVALID_PARAMETER;
+
+ if (pneg_ctxt->Flags != SMB2_COMPRESSION_CAPABILITIES_FLAG_NONE &&
+ pneg_ctxt->Flags != SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED)
+ return STATUS_INVALID_PARAMETER;
+
+ algs_size = alg_cnt * sizeof(__le16);
+ if (sizeof(struct smb2_neg_context) + 8 + algs_size > ctxt_len) {
+ pr_err("Invalid compression algorithm count(%d)\n", alg_cnt);
+ return STATUS_INVALID_PARAMETER;
+ }
+
+ for (i = 0; i < alg_cnt; i++) {
+ __le16 alg = pneg_ctxt->CompressionAlgorithms[i];
+
+ /*
+ * LZ77 is the required general-purpose codec. Pattern_V1 is an
+ * optional chained payload type and cannot stand alone.
+ */
+ if (alg == SMB3_COMPRESS_LZ77) {
+ conn->compress_algorithm = alg;
+ conn->compress_chained =
+ pneg_ctxt->Flags ==
+ SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED;
+ ksmbd_debug(SMB, "Compression Algorithm ID = 0x%x\n",
+ le16_to_cpu(alg));
+ } else if (alg == SMB3_COMPRESS_PATTERN) {
+ conn->compress_pattern = true;
+ }
+ }
+
+ if (conn->compress_algorithm == SMB3_COMPRESS_NONE ||
+ !conn->compress_chained)
+ conn->compress_pattern = false;
+
+ return STATUS_SUCCESS;
}
static void decode_sign_cap_ctxt(struct ksmbd_conn *conn,
unsigned int offset = le32_to_cpu(req->NegotiateContextOffset);
unsigned int neg_ctxt_cnt = le16_to_cpu(req->NegotiateContextCount);
__le32 status = STATUS_INVALID_PARAMETER;
+ int compress_ctxt_cnt = 0;
ksmbd_debug(SMB, "decoding %d negotiate contexts\n", neg_ctxt_cnt);
if (len_of_smb <= offset) {
} else if (pctx->ContextType == SMB2_COMPRESSION_CAPABILITIES) {
ksmbd_debug(SMB,
"deassemble SMB2_COMPRESSION_CAPABILITIES context\n");
- if (conn->compress_algorithm)
+ if (compress_ctxt_cnt++) {
+ status = STATUS_INVALID_PARAMETER;
break;
+ }
- decode_compress_ctxt(conn,
- (struct smb2_compression_capabilities_context *)pctx);
+ status = decode_compress_ctxt(conn,
+ (struct smb2_compression_capabilities_context *)
+ pctx, ctxt_len);
+ if (status != STATUS_SUCCESS)
+ break;
} else if (pctx->ContextType == SMB2_NETNAME_NEGOTIATE_CONTEXT_ID) {
ksmbd_debug(SMB,
"deassemble SMB2_NETNAME_NEGOTIATE_CONTEXT_ID context\n");
rsp->Reserved = 0;
/* default manual caching */
rsp->ShareFlags = SMB2_SHAREFLAG_MANUAL_CACHING;
+ /* Tell the client that READ requests may request compressed responses. */
+ if (conn->dialect == SMB311_PROT_ID &&
+ conn->compress_algorithm != SMB3_COMPRESS_NONE)
+ rsp->ShareFlags |= cpu_to_le32(SMB2_SHAREFLAG_COMPRESS_DATA);
rc = ksmbd_iov_pin_rsp(work, rsp, sizeof(struct smb2_tree_connect_rsp));
if (rc)