]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
ml-dsa: allow signature operations to be provided a μ value
authorPauli <ppzgs1@gmail.com>
Wed, 5 Feb 2025 03:06:04 +0000 (14:06 +1100)
committerTomas Mraz <tomas@openssl.org>
Fri, 14 Feb 2025 09:46:04 +0000 (10:46 +0100)
The μ value replaces the message and avoids some of the preliminary
processes.  This is part of FIPS 204.

Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Shane Lontis <shane.lontis@oracle.com>
Reviewed-by: Viktor Dukhovni <viktor@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/26637)

crypto/ml_dsa/ml_dsa_sign.c
include/crypto/ml_dsa.h
providers/implementations/keymgmt/ml_dsa_kmgmt.c
providers/implementations/signature/ml_dsa_sig.c

index 5b997f6a9f36051efc3bfec8bee0d1c7813ccf4e..7a3b6a6f78127761c8b2ed2ee4b4e4f94e3aee83 100644 (file)
@@ -46,7 +46,7 @@ static void signature_init(ML_DSA_SIG *sig,
  * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
  * @returns 1 on success and 0 on failure.
  */
-static int ml_dsa_sign_internal(const ML_DSA_KEY *priv,
+static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
                                 const uint8_t *encoded_msg,
                                 size_t encoded_msg_len,
                                 const uint8_t *rnd, size_t rnd_len,
@@ -67,7 +67,8 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv,
     VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
     MATRIX a_ntt;
     ML_DSA_SIG sig;
-    uint8_t mu[ML_DSA_MU_BYTES];
+    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
+    const size_t mu_len = sizeof(mu);
     uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
     uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
     size_t c_tilde_len = params->bit_strength >> 2;
@@ -107,12 +108,20 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv,
     signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
     /* End of the allocated blob setup */
 
-    if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt)
-            || !shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
-                            encoded_msg, encoded_msg_len, mu, sizeof(mu))
-            || !shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
-                            rnd, rnd_len, mu, sizeof(mu),
-                            rho_prime, sizeof(rho_prime)))
+    if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
+        goto err;
+    if (msg_is_mu) {
+        if (encoded_msg_len != mu_len)
+            goto err;
+        mu_ptr = (uint8_t *)encoded_msg;
+    } else {
+        if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
+                         encoded_msg, encoded_msg_len, mu_ptr, mu_len))
+            goto err;
+    }
+    if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
+                     rnd, rnd_len, mu_ptr, mu_len,
+                     rho_prime, sizeof(rho_prime)))
         goto err;
 
     vector_copy(&s1_ntt, &priv->s1);
@@ -143,7 +152,7 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv,
         vector_high_bits(&w, gamma2, &w1);
         ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
 
-        if (!shake_xof_2(md_ctx, priv->shake256_md, mu, sizeof(mu),
+        if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
                          w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
             break;
 
@@ -195,7 +204,7 @@ err:
 /*
  * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
  */
-static int ml_dsa_verify_internal(const ML_DSA_KEY *pub,
+static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
                                   const uint8_t *msg_enc, size_t msg_enc_len,
                                   const uint8_t *sig_enc, size_t sig_enc_len)
 {
@@ -214,7 +223,8 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub,
     size_t num_polys_k = 2 * k;
     size_t num_polys_l = 1 * l;
     size_t num_polys_k_by_l = k * l;
-    uint8_t mu[ML_DSA_MU_BYTES];
+    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
+    const size_t mu_len = sizeof(mu);
     uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
     uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
     EVP_MD_CTX *md_ctx = NULL;
@@ -246,10 +256,17 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub,
     vector_init(&ct1_ntt, p + k, k);
 
     if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
-            || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt)
-            || !shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
-                            msg_enc, msg_enc_len, mu, sizeof(mu)))
+            || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
         goto err;
+    if (msg_is_mu) {
+        if (msg_enc_len != mu_len)
+            goto err;
+        mu_ptr = (uint8_t *)msg_enc;
+    } else {
+        if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
+                         msg_enc, msg_enc_len, mu_ptr, mu_len))
+            goto err;
+    }
     /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */
     if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
                                  md_ctx, pub->shake256_md, params->tau))
@@ -275,7 +292,7 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub,
     vector_use_hint(&sig.hint, w_approx, gamma2, w1);
     ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
 
-    if (!shake_xof_3(md_ctx, pub->shake256_md, mu, sizeof(mu),
+    if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
                      w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
         goto err;
 
@@ -347,14 +364,14 @@ static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len,
  *
  * @returns 1 on success, or 0 on error.
  */
-int ossl_ml_dsa_sign(const ML_DSA_KEY *priv,
+int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
                      const uint8_t *msg, size_t msg_len,
                      const uint8_t *context, size_t context_len,
                      const uint8_t *rand, size_t rand_len, int encode,
                      unsigned char *sig, size_t *sig_len, size_t sig_size)
 {
     int ret = 1;
-    uint8_t m_tmp[1024], *m = m_tmp;
+    uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL;
     size_t m_len = 0;
 
     if (ossl_ml_dsa_key_get_priv(priv) == NULL)
@@ -362,13 +379,19 @@ int ossl_ml_dsa_sign(const ML_DSA_KEY *priv,
     if (sig != NULL) {
         if (sig_size < priv->params->sig_len)
             return 0;
-        m = msg_encode(msg, msg_len, context, context_len, encode,
-                       m_tmp, sizeof(m_tmp), &m_len);
-        if (m == NULL)
-            return 0;
-        ret = ml_dsa_sign_internal(priv, m, m_len, rand, rand_len, sig);
-        if (m != msg && m != m_tmp)
-            OPENSSL_free(m);
+        if (msg_is_mu) {
+            m = (uint8_t *)msg;
+            m_len = msg_len;
+        } else {
+            m = msg_encode(msg, msg_len, context, context_len, encode,
+                           m_tmp, sizeof(m_tmp), &m_len);
+            if (m == NULL)
+                return 0;
+            if (m != msg && m != m_tmp)
+                alloced_m = m;
+        }
+        ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig);
+        OPENSSL_free(alloced_m);
     }
     if (sig_len != NULL)
         *sig_len = priv->params->sig_len;
@@ -379,12 +402,12 @@ int ossl_ml_dsa_sign(const ML_DSA_KEY *priv,
  * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify()
  * @returns 1 on success, or 0 on error.
  */
-int ossl_ml_dsa_verify(const ML_DSA_KEY *pub,
+int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
                        const uint8_t *msg, size_t msg_len,
                        const uint8_t *context, size_t context_len, int encode,
                        const uint8_t *sig, size_t sig_len)
 {
-    uint8_t *m;
+    uint8_t *m, *alloced_m = NULL;
     size_t m_len;
     uint8_t m_tmp[1024];
     int ret = 0;
@@ -392,13 +415,19 @@ int ossl_ml_dsa_verify(const ML_DSA_KEY *pub,
     if (ossl_ml_dsa_key_get_pub(pub) == NULL)
         return 0;
 
-    m = msg_encode(msg, msg_len, context, context_len, encode,
-                   m_tmp, sizeof(m_tmp), &m_len);
-    if (m == NULL)
-        return 0;
+    if (msg_is_mu) {
+        m = (uint8_t *)msg;
+        m_len = msg_len;
+    } else {
+        m = msg_encode(msg, msg_len, context, context_len, encode,
+                       m_tmp, sizeof(m_tmp), &m_len);
+        if (m == NULL)
+            return 0;
+        if (m != msg && m != m_tmp)
+            alloced_m = m;
+    }
 
-    ret = ml_dsa_verify_internal(pub, m, m_len, sig, sig_len);
-    if (m != msg && m != m_tmp)
-        OPENSSL_free(m);
+    ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len);
+    OPENSSL_free(alloced_m);
     return ret;
 }
index b3fef85e39eec36a1e24686f8571fa657eb8ee1e..f66d93d06d35b663a9b091e102fd8510b41c5ee1 100644 (file)
@@ -105,12 +105,12 @@ __owur int ossl_ml_dsa_key_public_from_private(ML_DSA_KEY *key);
 __owur int ossl_ml_dsa_pk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len);
 __owur int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len);
 
-__owur int ossl_ml_dsa_sign(const ML_DSA_KEY *priv,
+__owur int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
                             const uint8_t *msg, size_t msg_len,
                             const uint8_t *context, size_t context_len,
                             const uint8_t *rand, size_t rand_len, int encode,
                             unsigned char *sig, size_t *siglen, size_t sigsize);
-__owur int ossl_ml_dsa_verify(const ML_DSA_KEY *pub,
+__owur int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
                               const uint8_t *msg, size_t msg_len,
                               const uint8_t *context, size_t context_len,
                               int encode, const uint8_t *sig, size_t sig_len);
index c38a5a3ec43558ae94cd3980bb581fc6934b3829..f1f10d86871a341cc5793d4fa94da32cee168e9b 100644 (file)
@@ -77,11 +77,11 @@ static int ml_dsa_pairwise_test(const ML_DSA_KEY *key)
     memset(rnd, 0, sizeof(rnd));
     memset(sig, 0, sizeof(sig));
 
-    if (ossl_ml_dsa_sign(key, msg, sizeof(msg), NULL, 0, rnd, sizeof(rnd), 0,
+    if (ossl_ml_dsa_sign(key, 0, msg, sizeof(msg), NULL, 0, rnd, sizeof(rnd), 0,
                          sig, &sig_len, sizeof(sig)) <= 0)
         goto err;
 
-    if (ossl_ml_dsa_verify(key, msg, sizeof(msg), NULL, 0, 0,
+    if (ossl_ml_dsa_verify(key, 0, msg, sizeof(msg), NULL, 0, 0,
                            sig, sig_len) <= 0)
         goto err;
 
index 36d28d75dea1090f9e3d38c1c72a9fd9031285b3..f109a66f3360f94e37a7b4b27dfe59d5b0fa3314 100644 (file)
@@ -53,6 +53,7 @@ typedef struct {
     /* The Algorithm Identifier of the signature algorithm */
     uint8_t aid_buf[OSSL_MAX_ALGORITHM_ID_SIZE];
     size_t  aid_len;
+    int mu;     /* Flag indicating we should begin from \mu, not the message */
 } PROV_ML_DSA_CTX;
 
 static void ml_dsa_freectx(void *vctx)
@@ -143,6 +144,7 @@ static int ml_dsa_signverify_msg_init(void *vctx, void *vkey,
         return 0;
 
     set_alg_id_buffer(ctx);
+    ctx->mu = 0;
 
     return ml_dsa_set_ctx_params(ctx, params);
 }
@@ -164,6 +166,8 @@ static int ml_dsa_digest_signverify_init(void *vctx, const char *mdname,
         return 0;
     }
 
+    ctx->mu = 0;
+
     if (vkey == NULL && ctx->key != NULL)
         return ml_dsa_set_ctx_params(ctx, params);
 
@@ -193,7 +197,7 @@ static int ml_dsa_sign(void *vctx, uint8_t *sig, size_t *siglen, size_t sigsize,
                 return 0;
         }
     }
-    ret = ossl_ml_dsa_sign(ctx->key, msg, msg_len,
+    ret = ossl_ml_dsa_sign(ctx->key, ctx->mu, msg, msg_len,
                            ctx->context_string, ctx->context_string_len,
                            rnd, sizeof(rand_tmp), ctx->msg_encode,
                            sig, siglen, sigsize);
@@ -221,7 +225,7 @@ static int ml_dsa_verify(void *vctx, const uint8_t *sig, size_t siglen,
 
     if (!ossl_prov_is_running())
         return 0;
-    return ossl_ml_dsa_verify(ctx->key, msg, msg_len,
+    return ossl_ml_dsa_verify(ctx->key, ctx->mu, msg, msg_len,
                               ctx->context_string, ctx->context_string_len,
                               ctx->msg_encode, sig, siglen);
 }
@@ -273,6 +277,11 @@ static int ml_dsa_set_ctx_params(void *vctx, const OSSL_PARAM params[])
     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MESSAGE_ENCODING);
     if (p != NULL && !OSSL_PARAM_get_int(p, &pctx->msg_encode))
         return 0;
+
+    p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MU);
+    if (p != NULL && !OSSL_PARAM_get_int(p, &pctx->mu))
+        return 0;
+
     return 1;
 }
 
@@ -283,6 +292,7 @@ static const OSSL_PARAM *ml_dsa_settable_ctx_params(void *vctx,
         OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_CONTEXT_STRING, NULL, 0),
         OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_TEST_ENTROPY, NULL, 0),
         OSSL_PARAM_int(OSSL_SIGNATURE_PARAM_DETERMINISTIC, 0),
+        OSSL_PARAM_int(OSSL_SIGNATURE_PARAM_MU, 0),
         OSSL_PARAM_int(OSSL_SIGNATURE_PARAM_MESSAGE_ENCODING, 0),
         OSSL_PARAM_END
     };