From: Pauli Date: Wed, 5 Feb 2025 03:06:04 +0000 (+1100) Subject: ml-dsa: allow signature operations to be provided a μ value X-Git-Tag: openssl-3.5.0-alpha1~565 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=55738c152084857159541b89bf993ba08b0d1524;p=thirdparty%2Fopenssl.git ml-dsa: allow signature operations to be provided a μ value The μ value replaces the message and avoids some of the preliminary processes. This is part of FIPS 204. Reviewed-by: Tomas Mraz Reviewed-by: Shane Lontis Reviewed-by: Viktor Dukhovni (Merged from https://github.com/openssl/openssl/pull/26637) --- diff --git a/crypto/ml_dsa/ml_dsa_sign.c b/crypto/ml_dsa/ml_dsa_sign.c index 5b997f6a9f3..7a3b6a6f781 100644 --- a/crypto/ml_dsa/ml_dsa_sign.c +++ b/crypto/ml_dsa/ml_dsa_sign.c @@ -46,7 +46,7 @@ static void signature_init(ML_DSA_SIG *sig, * FIPS 204, Algorithm 7, ML-DSA.Sign_internal() * @returns 1 on success and 0 on failure. */ -static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, +static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu, const uint8_t *encoded_msg, size_t encoded_msg_len, const uint8_t *rnd, size_t rnd_len, @@ -67,7 +67,8 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y; MATRIX a_ntt; ML_DSA_SIG sig; - uint8_t mu[ML_DSA_MU_BYTES]; + uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu; + const size_t mu_len = sizeof(mu); uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES]; uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4]; size_t c_tilde_len = params->bit_strength >> 2; @@ -107,12 +108,20 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len); /* End of the allocated blob setup */ - if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt) - || !shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr), - encoded_msg, encoded_msg_len, mu, sizeof(mu)) - || !shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K), - rnd, rnd_len, mu, sizeof(mu), - rho_prime, sizeof(rho_prime))) + if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt)) + goto err; + if (msg_is_mu) { + if (encoded_msg_len != mu_len) + goto err; + mu_ptr = (uint8_t *)encoded_msg; + } else { + if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr), + encoded_msg, encoded_msg_len, mu_ptr, mu_len)) + goto err; + } + if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K), + rnd, rnd_len, mu_ptr, mu_len, + rho_prime, sizeof(rho_prime))) goto err; vector_copy(&s1_ntt, &priv->s1); @@ -143,7 +152,7 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, vector_high_bits(&w, gamma2, &w1); ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len); - if (!shake_xof_2(md_ctx, priv->shake256_md, mu, sizeof(mu), + if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len, w1_encoded, w1_encoded_len, c_tilde, c_tilde_len)) break; @@ -195,7 +204,7 @@ err: /* * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal(). */ -static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, +static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu, const uint8_t *msg_enc, size_t msg_enc_len, const uint8_t *sig_enc, size_t sig_enc_len) { @@ -214,7 +223,8 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, size_t num_polys_k = 2 * k; size_t num_polys_l = 1 * l; size_t num_polys_k_by_l = k * l; - uint8_t mu[ML_DSA_MU_BYTES]; + uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu; + const size_t mu_len = sizeof(mu); uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4]; uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4]; EVP_MD_CTX *md_ctx = NULL; @@ -246,10 +256,17 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, vector_init(&ct1_ntt, p + k, k); if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params) - || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt) - || !shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr), - msg_enc, msg_enc_len, mu, sizeof(mu))) + || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt)) goto err; + if (msg_is_mu) { + if (msg_enc_len != mu_len) + goto err; + mu_ptr = (uint8_t *)msg_enc; + } else { + if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr), + msg_enc, msg_enc_len, mu_ptr, mu_len)) + goto err; + } /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */ if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len, md_ctx, pub->shake256_md, params->tau)) @@ -275,7 +292,7 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, vector_use_hint(&sig.hint, w_approx, gamma2, w1); ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len); - if (!shake_xof_3(md_ctx, pub->shake256_md, mu, sizeof(mu), + if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len, w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len)) goto err; @@ -347,14 +364,14 @@ static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len, * * @returns 1 on success, or 0 on error. */ -int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, +int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu, const uint8_t *msg, size_t msg_len, const uint8_t *context, size_t context_len, const uint8_t *rand, size_t rand_len, int encode, unsigned char *sig, size_t *sig_len, size_t sig_size) { int ret = 1; - uint8_t m_tmp[1024], *m = m_tmp; + uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL; size_t m_len = 0; if (ossl_ml_dsa_key_get_priv(priv) == NULL) @@ -362,13 +379,19 @@ int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, if (sig != NULL) { if (sig_size < priv->params->sig_len) return 0; - m = msg_encode(msg, msg_len, context, context_len, encode, - m_tmp, sizeof(m_tmp), &m_len); - if (m == NULL) - return 0; - ret = ml_dsa_sign_internal(priv, m, m_len, rand, rand_len, sig); - if (m != msg && m != m_tmp) - OPENSSL_free(m); + if (msg_is_mu) { + m = (uint8_t *)msg; + m_len = msg_len; + } else { + m = msg_encode(msg, msg_len, context, context_len, encode, + m_tmp, sizeof(m_tmp), &m_len); + if (m == NULL) + return 0; + if (m != msg && m != m_tmp) + alloced_m = m; + } + ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig); + OPENSSL_free(alloced_m); } if (sig_len != NULL) *sig_len = priv->params->sig_len; @@ -379,12 +402,12 @@ int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify() * @returns 1 on success, or 0 on error. */ -int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, +int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu, const uint8_t *msg, size_t msg_len, const uint8_t *context, size_t context_len, int encode, const uint8_t *sig, size_t sig_len) { - uint8_t *m; + uint8_t *m, *alloced_m = NULL; size_t m_len; uint8_t m_tmp[1024]; int ret = 0; @@ -392,13 +415,19 @@ int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, if (ossl_ml_dsa_key_get_pub(pub) == NULL) return 0; - m = msg_encode(msg, msg_len, context, context_len, encode, - m_tmp, sizeof(m_tmp), &m_len); - if (m == NULL) - return 0; + if (msg_is_mu) { + m = (uint8_t *)msg; + m_len = msg_len; + } else { + m = msg_encode(msg, msg_len, context, context_len, encode, + m_tmp, sizeof(m_tmp), &m_len); + if (m == NULL) + return 0; + if (m != msg && m != m_tmp) + alloced_m = m; + } - ret = ml_dsa_verify_internal(pub, m, m_len, sig, sig_len); - if (m != msg && m != m_tmp) - OPENSSL_free(m); + ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len); + OPENSSL_free(alloced_m); return ret; } diff --git a/include/crypto/ml_dsa.h b/include/crypto/ml_dsa.h index b3fef85e39e..f66d93d06d3 100644 --- a/include/crypto/ml_dsa.h +++ b/include/crypto/ml_dsa.h @@ -105,12 +105,12 @@ __owur int ossl_ml_dsa_key_public_from_private(ML_DSA_KEY *key); __owur int ossl_ml_dsa_pk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len); __owur int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len); -__owur int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, +__owur int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu, const uint8_t *msg, size_t msg_len, const uint8_t *context, size_t context_len, const uint8_t *rand, size_t rand_len, int encode, unsigned char *sig, size_t *siglen, size_t sigsize); -__owur int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, +__owur int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu, const uint8_t *msg, size_t msg_len, const uint8_t *context, size_t context_len, int encode, const uint8_t *sig, size_t sig_len); diff --git a/providers/implementations/keymgmt/ml_dsa_kmgmt.c b/providers/implementations/keymgmt/ml_dsa_kmgmt.c index c38a5a3ec43..f1f10d86871 100644 --- a/providers/implementations/keymgmt/ml_dsa_kmgmt.c +++ b/providers/implementations/keymgmt/ml_dsa_kmgmt.c @@ -77,11 +77,11 @@ static int ml_dsa_pairwise_test(const ML_DSA_KEY *key) memset(rnd, 0, sizeof(rnd)); memset(sig, 0, sizeof(sig)); - if (ossl_ml_dsa_sign(key, msg, sizeof(msg), NULL, 0, rnd, sizeof(rnd), 0, + if (ossl_ml_dsa_sign(key, 0, msg, sizeof(msg), NULL, 0, rnd, sizeof(rnd), 0, sig, &sig_len, sizeof(sig)) <= 0) goto err; - if (ossl_ml_dsa_verify(key, msg, sizeof(msg), NULL, 0, 0, + if (ossl_ml_dsa_verify(key, 0, msg, sizeof(msg), NULL, 0, 0, sig, sig_len) <= 0) goto err; diff --git a/providers/implementations/signature/ml_dsa_sig.c b/providers/implementations/signature/ml_dsa_sig.c index 36d28d75dea..f109a66f336 100644 --- a/providers/implementations/signature/ml_dsa_sig.c +++ b/providers/implementations/signature/ml_dsa_sig.c @@ -53,6 +53,7 @@ typedef struct { /* The Algorithm Identifier of the signature algorithm */ uint8_t aid_buf[OSSL_MAX_ALGORITHM_ID_SIZE]; size_t aid_len; + int mu; /* Flag indicating we should begin from \mu, not the message */ } PROV_ML_DSA_CTX; static void ml_dsa_freectx(void *vctx) @@ -143,6 +144,7 @@ static int ml_dsa_signverify_msg_init(void *vctx, void *vkey, return 0; set_alg_id_buffer(ctx); + ctx->mu = 0; return ml_dsa_set_ctx_params(ctx, params); } @@ -164,6 +166,8 @@ static int ml_dsa_digest_signverify_init(void *vctx, const char *mdname, return 0; } + ctx->mu = 0; + if (vkey == NULL && ctx->key != NULL) return ml_dsa_set_ctx_params(ctx, params); @@ -193,7 +197,7 @@ static int ml_dsa_sign(void *vctx, uint8_t *sig, size_t *siglen, size_t sigsize, return 0; } } - ret = ossl_ml_dsa_sign(ctx->key, msg, msg_len, + ret = ossl_ml_dsa_sign(ctx->key, ctx->mu, msg, msg_len, ctx->context_string, ctx->context_string_len, rnd, sizeof(rand_tmp), ctx->msg_encode, sig, siglen, sigsize); @@ -221,7 +225,7 @@ static int ml_dsa_verify(void *vctx, const uint8_t *sig, size_t siglen, if (!ossl_prov_is_running()) return 0; - return ossl_ml_dsa_verify(ctx->key, msg, msg_len, + return ossl_ml_dsa_verify(ctx->key, ctx->mu, msg, msg_len, ctx->context_string, ctx->context_string_len, ctx->msg_encode, sig, siglen); } @@ -273,6 +277,11 @@ static int ml_dsa_set_ctx_params(void *vctx, const OSSL_PARAM params[]) p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MESSAGE_ENCODING); if (p != NULL && !OSSL_PARAM_get_int(p, &pctx->msg_encode)) return 0; + + p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MU); + if (p != NULL && !OSSL_PARAM_get_int(p, &pctx->mu)) + return 0; + return 1; } @@ -283,6 +292,7 @@ static const OSSL_PARAM *ml_dsa_settable_ctx_params(void *vctx, OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_CONTEXT_STRING, NULL, 0), OSSL_PARAM_octet_string(OSSL_SIGNATURE_PARAM_TEST_ENTROPY, NULL, 0), OSSL_PARAM_int(OSSL_SIGNATURE_PARAM_DETERMINISTIC, 0), + OSSL_PARAM_int(OSSL_SIGNATURE_PARAM_MU, 0), OSSL_PARAM_int(OSSL_SIGNATURE_PARAM_MESSAGE_ENCODING, 0), OSSL_PARAM_END };