]> git.ipfire.org Git - thirdparty/kernel/linux.git/commitdiff
nvme-auth: common: use crypto library in nvme_auth_derive_tls_psk()
authorEric Biggers <ebiggers@kernel.org>
Mon, 2 Mar 2026 07:59:50 +0000 (23:59 -0800)
committerKeith Busch <kbusch@kernel.org>
Fri, 27 Mar 2026 14:35:01 +0000 (07:35 -0700)
For the HKDF-Expand-Label computation in nvme_auth_derive_tls_psk(), use
the crypto library instead of crypto_shash and crypto/hkdf.c.

While this means the HKDF "helper" functions are no longer utilized,
they clearly weren't buying us much: it's simpler to just inline the
HMAC computations directly, and this code needs to be tested anyway.  (A
similar result was seen in fs/crypto/.  As a result, this eliminates the
last user of crypto/hkdf.c, which we'll be able to remove as well.)

As usual this is also a lot more efficient, eliminating the allocation
of a transformation object and multiple other dynamic allocations.

Acked-by: Ard Biesheuvel <ardb@kernel.org>
Acked-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Hannes Reinecke <hare@suse.de>
Signed-off-by: Eric Biggers <ebiggers@kernel.org>
Signed-off-by: Keith Busch <kbusch@kernel.org>
drivers/nvme/common/auth.c

index f0b4e1c6ade7e491c4f8e70deca1335e7f8a8580..5be86629c2d411d1c3040bc9fb8858231117750d 100644 (file)
@@ -9,9 +9,7 @@
 #include <linux/prandom.h>
 #include <linux/scatterlist.h>
 #include <linux/unaligned.h>
-#include <crypto/hash.h>
 #include <crypto/dh.h>
-#include <crypto/hkdf.h>
 #include <crypto/sha2.h>
 #include <linux/nvme.h>
 #include <linux/nvme-auth.h>
@@ -621,59 +619,6 @@ out:
 }
 EXPORT_SYMBOL_GPL(nvme_auth_generate_digest);
 
-/**
- * hkdf_expand_label - HKDF-Expand-Label (RFC 8846 section 7.1)
- * @hmac_tfm: hash context keyed with pseudorandom key
- * @label: ASCII label without "tls13 " prefix
- * @labellen: length of @label
- * @context: context bytes
- * @contextlen: length of @context
- * @okm: output keying material
- * @okmlen: length of @okm
- *
- * Build the TLS 1.3 HkdfLabel structure and invoke hkdf_expand().
- *
- * Returns 0 on success with output keying material stored in @okm,
- * or a negative errno value otherwise.
- */
-static int hkdf_expand_label(struct crypto_shash *hmac_tfm,
-               const u8 *label, unsigned int labellen,
-               const u8 *context, unsigned int contextlen,
-               u8 *okm, unsigned int okmlen)
-{
-       int err;
-       u8 *info;
-       unsigned int infolen;
-       const char *tls13_prefix = "tls13 ";
-       unsigned int prefixlen = strlen(tls13_prefix);
-
-       if (WARN_ON(labellen > (255 - prefixlen)))
-               return -EINVAL;
-       if (WARN_ON(contextlen > 255))
-               return -EINVAL;
-
-       infolen = 2 + (1 + prefixlen + labellen) + (1 + contextlen);
-       info = kzalloc(infolen, GFP_KERNEL);
-       if (!info)
-               return -ENOMEM;
-
-       /* HkdfLabel.Length */
-       put_unaligned_be16(okmlen, info);
-
-       /* HkdfLabel.Label */
-       info[2] = prefixlen + labellen;
-       memcpy(info + 3, tls13_prefix, prefixlen);
-       memcpy(info + 3 + prefixlen, label, labellen);
-
-       /* HkdfLabel.Context */
-       info[3 + prefixlen + labellen] = contextlen;
-       memcpy(info + 4 + prefixlen + labellen, context, contextlen);
-
-       err = hkdf_expand(hmac_tfm, info, infolen, okm, okmlen);
-       kfree_sensitive(info);
-       return err;
-}
-
 /**
  * nvme_auth_derive_tls_psk - Derive TLS PSK
  * @hmac_id: Hash function identifier
@@ -704,84 +649,89 @@ static int hkdf_expand_label(struct crypto_shash *hmac_tfm,
 int nvme_auth_derive_tls_psk(int hmac_id, const u8 *psk, size_t psk_len,
                             const char *psk_digest, u8 **ret_psk)
 {
-       struct crypto_shash *hmac_tfm;
-       const char *hmac_name;
-       const char *label = "nvme-tls-psk";
        static const u8 default_salt[NVME_AUTH_MAX_DIGEST_SIZE];
-       size_t prk_len;
-       const char *ctx;
-       u8 *prk, *tls_key;
+       static const char label[] = "tls13 nvme-tls-psk";
+       const size_t label_len = sizeof(label) - 1;
+       u8 prk[NVME_AUTH_MAX_DIGEST_SIZE];
+       size_t hash_len, ctx_len;
+       u8 *hmac_data = NULL, *tls_key;
+       size_t i;
        int ret;
 
-       hmac_name = nvme_auth_hmac_name(hmac_id);
-       if (!hmac_name) {
+       hash_len = nvme_auth_hmac_hash_len(hmac_id);
+       if (hash_len == 0) {
                pr_warn("%s: invalid hash algorithm %d\n",
                        __func__, hmac_id);
                return -EINVAL;
        }
        if (hmac_id == NVME_AUTH_HASH_SHA512) {
                pr_warn("%s: unsupported hash algorithm %s\n",
-                       __func__, hmac_name);
+                       __func__, nvme_auth_hmac_name(hmac_id));
                return -EINVAL;
        }
 
-       if (psk_len != nvme_auth_hmac_hash_len(hmac_id)) {
+       if (psk_len != hash_len) {
                pr_warn("%s: unexpected psk_len %zu\n", __func__, psk_len);
                return -EINVAL;
        }
 
-       hmac_tfm = crypto_alloc_shash(hmac_name, 0, 0);
-       if (IS_ERR(hmac_tfm))
-               return PTR_ERR(hmac_tfm);
+       /* HKDF-Extract */
+       ret = nvme_auth_hmac(hmac_id, default_salt, hash_len, psk, psk_len,
+                            prk);
+       if (ret)
+               goto out;
+
+       /*
+        * HKDF-Expand-Label (RFC 8446 section 7.1), with output length equal to
+        * the hash length (so only a single HMAC operation is needed)
+        */
 
-       prk_len = crypto_shash_digestsize(hmac_tfm);
-       prk = kzalloc(prk_len, GFP_KERNEL);
-       if (!prk) {
+       hmac_data = kmalloc(/* output length */ 2 +
+                           /* label */ 1 + label_len +
+                           /* context (max) */ 1 + 3 + 1 + strlen(psk_digest) +
+                           /* counter */ 1,
+                           GFP_KERNEL);
+       if (!hmac_data) {
                ret = -ENOMEM;
-               goto out_free_shash;
+               goto out;
        }
-
-       if (WARN_ON(prk_len > NVME_AUTH_MAX_DIGEST_SIZE)) {
+       /* output length */
+       i = 0;
+       hmac_data[i++] = hash_len >> 8;
+       hmac_data[i++] = hash_len;
+
+       /* label */
+       static_assert(label_len <= 255);
+       hmac_data[i] = label_len;
+       memcpy(&hmac_data[i + 1], label, label_len);
+       i += 1 + label_len;
+
+       /* context */
+       ctx_len = sprintf(&hmac_data[i + 1], "%02d %s", hmac_id, psk_digest);
+       if (ctx_len > 255) {
                ret = -EINVAL;
-               goto out_free_prk;
+               goto out;
        }
-       ret = hkdf_extract(hmac_tfm, psk, psk_len,
-                          default_salt, prk_len, prk);
-       if (ret)
-               goto out_free_prk;
+       hmac_data[i] = ctx_len;
+       i += 1 + ctx_len;
 
-       ret = crypto_shash_setkey(hmac_tfm, prk, prk_len);
-       if (ret)
-               goto out_free_prk;
-
-       ctx = kasprintf(GFP_KERNEL, "%02d %s", hmac_id, psk_digest);
-       if (!ctx) {
-               ret = -ENOMEM;
-               goto out_free_prk;
-       }
+       /* counter (this overwrites the NUL terminator written by sprintf) */
+       hmac_data[i++] = 1;
 
        tls_key = kzalloc(psk_len, GFP_KERNEL);
        if (!tls_key) {
                ret = -ENOMEM;
-               goto out_free_ctx;
+               goto out;
        }
-       ret = hkdf_expand_label(hmac_tfm,
-                               label, strlen(label),
-                               ctx, strlen(ctx),
-                               tls_key, psk_len);
+       ret = nvme_auth_hmac(hmac_id, prk, hash_len, hmac_data, i, tls_key);
        if (ret) {
-               kfree(tls_key);
-               goto out_free_ctx;
+               kfree_sensitive(tls_key);
+               goto out;
        }
        *ret_psk = tls_key;
-
-out_free_ctx:
-       kfree(ctx);
-out_free_prk:
-       kfree(prk);
-out_free_shash:
-       crypto_free_shash(hmac_tfm);
-
+out:
+       kfree_sensitive(hmac_data);
+       memzero_explicit(prk, sizeof(prk));
        return ret;
 }
 EXPORT_SYMBOL_GPL(nvme_auth_derive_tls_psk);