From 90f0137453aaec5f09d26fda91c6025ae25e4130 Mon Sep 17 00:00:00 2001 From: Simo Sorce Date: Wed, 9 Apr 2025 09:35:20 -0400 Subject: [PATCH] Split the ML-DSA internal sigver functions Deconstruct the functions into 2 parts: - mu computation (if needed) - actual signing/verification Adds helper to compute mu that is split in 3 parts (init/update/finalize) where the update part can be used to feed the message to be signed or verified in chunks of any size. Signed-off-by: Simo Sorce Reviewed-by: Viktor Dukhovni Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/27342) --- crypto/ml_dsa/ml_dsa_sign.c | 330 ++++++++++++++++++++++-------------- 1 file changed, 200 insertions(+), 130 deletions(-) diff --git a/crypto/ml_dsa/ml_dsa_sign.c b/crypto/ml_dsa/ml_dsa_sign.c index 346f635094e..bbeb95e2a35 100644 --- a/crypto/ml_dsa/ml_dsa_sign.c +++ b/crypto/ml_dsa/ml_dsa_sign.c @@ -11,6 +11,9 @@ #include #include #include +#include +#include +#include "internal/common.h" #include "ml_dsa_local.h" #include "ml_dsa_key.h" #include "ml_dsa_matrix.h" @@ -43,12 +46,115 @@ static void signature_init(ML_DSA_SIG *sig, } /* - * FIPS 204, Algorithm 7, ML-DSA.Sign_internal() - * @returns 1 on success and 0 on failure. + * @brief: Auxiliary functions to compute ML-DSA's MU. + * This combines the steps of creating M' and concatenating it + * to the Public Key Hash to obtain MU. + * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5) as + * well as Algorithm 7 Step 6 (and algorithm 8 Step 7) + * + * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg + * Where ctx is the empty string by default and ctx_len <= 255. + * The message is appended to the encoded context. + * Finally a public key hash is prepended, and the whole is hashed + * to derive the mu value. + * + * @param key: A public or private ML-DSA key; + * @param encode: if not set, assumes that M' is provided raw and the + * following parameters are ignored. + * @param ctx An optional context to add to the message encoding. + * @param ctx_len The size of |ctx|. It must be in the range 0..255 + * @returns an EVP_MD_CTX if the operation is successful, NULL otherwise. + */ + +static EVP_MD_CTX *ml_dsa_mu_init(const ML_DSA_KEY *key, int encode, + const uint8_t *ctx, size_t ctx_len) +{ + EVP_MD_CTX *md_ctx; + uint8_t itb[2]; + + if (key == NULL) + return NULL; + + md_ctx = EVP_MD_CTX_new(); + if (md_ctx == NULL) + return NULL; + + /* H(.. */ + if (!EVP_DigestInit_ex2(md_ctx, key->shake256_md, NULL)) + goto err; + /* ..pk (= key->tr) */ + if (!EVP_DigestUpdate(md_ctx, key->tr, sizeof(key->tr))) + goto err; + /* M' = .. */ + if (encode) { + if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN) + goto err; + /* IntegerToBytes(0, 1) .. */ + itb[0] = 0; + /* || IntegerToBytes(|ctx|, 1) || .. */ + itb[1] = (uint8_t)ctx_len; + if (!EVP_DigestUpdate(md_ctx, itb, 2)) + goto err; + /* ctx || .. */ + if (!EVP_DigestUpdate(md_ctx, ctx, ctx_len)) + goto err; + /* .. msg) will follow in update and final functions */ + } + + return md_ctx; + +err: + EVP_MD_CTX_free(md_ctx); + return NULL; +} + +/* + * @brief: updates the internal ML-DSA hash with an additional message chunk. + * + * @param md_ctx: The hashing context + * @param msg: The next message chunk + * @param msg_len: The length of the msg buffer to process + * @returns 1 on success, 0 on error + */ +static int ml_dsa_mu_update(EVP_MD_CTX *md_ctx, + const uint8_t *msg, size_t msg_len) +{ + return EVP_DigestUpdate(md_ctx, msg, msg_len); +} + +/* + * @brief: finalizes the internal ML-DSA hash + * + * @param md_ctx: The hashing context + * @param mu: The output buffer for Mu + * @param mu_len: The size of the output buffer + * @returns 1 on success, 0 on error + */ +static int ml_dsa_mu_finalize(EVP_MD_CTX *md_ctx, uint8_t *mu, size_t mu_len) +{ + if (!ossl_assert(mu_len == ML_DSA_MU_BYTES)) { + ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH); + return 0; + } + return EVP_DigestSqueeze(md_ctx, mu, mu_len); +} + +/* + * @brief FIPS 204, Algorithm 7, ML-DSA.Sign_internal() + * + * This algorithm is decomposed in 2 steps, a set of functions to compute mu + * and then the actual signing function. + * + * @param priv: The private ML-DSA key + * @param mu: The pre-computed mu hash + * @param mu_len: The length of the mu buffer + * @param rnd: The random buffer + * @param rnd_len: The length of the random buffer + * @param out_sig: The output signature buffer + * @returns 1 on success, 0 on error */ -static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu, - const uint8_t *encoded_msg, - size_t encoded_msg_len, +static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, + const uint8_t *mu, size_t mu_len, const uint8_t *rnd, size_t rnd_len, uint8_t *out_sig) { @@ -63,25 +169,28 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu, size_t num_polys_k = 5 * k; size_t num_polys_l = 3 * l; size_t num_polys_k_by_l = k * l; - POLY *polys = NULL, *p, *c_ntt; + POLY *p, *c_ntt; VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y; MATRIX a_ntt; ML_DSA_SIG sig; - uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu; - const size_t mu_len = sizeof(mu); uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES]; uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4]; size_t c_tilde_len = params->bit_strength >> 2; size_t kappa; + if (mu_len != ML_DSA_MU_BYTES) { + ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH); + return 0; + } + /* * Allocate a single blob for most of the variable size temporary variables. * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K). */ w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128); alloc_len = w1_encoded_len - + sizeof(*polys) * (1 + num_polys_k + num_polys_l - + num_polys_k_by_l + num_polys_sig_k); + + sizeof(*p) * (1 + num_polys_k + num_polys_l + + num_polys_k_by_l + num_polys_sig_k); alloc = OPENSSL_malloc(alloc_len); if (alloc == NULL) return 0; @@ -110,17 +219,9 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu, if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt)) goto err; - if (msg_is_mu) { - if (encoded_msg_len != mu_len) - goto err; - mu_ptr = (uint8_t *)encoded_msg; - } else { - if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr), - encoded_msg, encoded_msg_len, mu_ptr, mu_len)) - goto err; - } + if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K), - rnd, rnd_len, mu_ptr, mu_len, + rnd, rnd_len, mu, mu_len, rho_prime, sizeof(rho_prime))) goto err; @@ -152,7 +253,7 @@ static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu, vector_high_bits(&w, gamma2, &w1); ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len); - if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len, + if (!shake_xof_2(md_ctx, priv->shake256_md, mu, mu_len, w1_encoded, w1_encoded_len, c_tilde, c_tilde_len)) break; @@ -202,15 +303,26 @@ err: } /* - * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal(). + * @brief FIPS 204, Algorithm 8, ML-DSA.Verify_internal(). + * + * This algorithm is decomposed in 2 steps, a set of functions to compute mu + * and then the actual verification function. + * + * @param pub: The public ML-DSA key + * @param mu: The pre-computed mu hash + * @param mu_len: The length of the mu buffer + * @param sig_enc: The encoded signature to be verified + * @param sig_enc_len: the encoded csignature length + * @returns 1 on success, 0 on error */ -static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu, - const uint8_t *msg_enc, size_t msg_enc_len, - const uint8_t *sig_enc, size_t sig_enc_len) +static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, + const uint8_t *mu, size_t mu_len, + const uint8_t *sig_enc, + size_t sig_enc_len) { int ret = 0; uint8_t *alloc = NULL, *w1_encoded; - POLY *polys = NULL, *p, *c_ntt; + POLY *p, *c_ntt; MATRIX a_ntt; VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx; ML_DSA_SIG sig; @@ -223,21 +335,25 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu, size_t num_polys_k = 2 * k; size_t num_polys_l = 1 * l; size_t num_polys_k_by_l = k * l; - uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu; - const size_t mu_len = sizeof(mu); uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4]; uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4]; EVP_MD_CTX *md_ctx = NULL; size_t c_tilde_len = params->bit_strength >> 2; uint32_t z_max; + if (mu_len != ML_DSA_MU_BYTES) { + ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH); + return 0; + } + + /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */ w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128); alloc = OPENSSL_malloc(w1_encoded_len - + sizeof(*polys) * (1 + num_polys_k - + num_polys_l - + num_polys_k_by_l - + num_polys_sig)); + + sizeof(*p) * (1 + num_polys_k + + num_polys_l + + num_polys_k_by_l + + num_polys_sig)); if (alloc == NULL) return 0; md_ctx = EVP_MD_CTX_new(); @@ -258,16 +374,8 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu, if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params) || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt)) goto err; - if (msg_is_mu) { - if (msg_enc_len != mu_len) - goto err; - mu_ptr = (uint8_t *)msg_enc; - } else { - if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr), - msg_enc, msg_enc_len, mu_ptr, mu_len)) - goto err; - } - /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */ + + /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde)) */ if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len, md_ctx, pub->shake256_md, params->tau)) goto err; @@ -292,7 +400,7 @@ static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu, vector_use_hint(&sig.hint, w_approx, gamma2, w1); ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len); - if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len, + if (!shake_xof_3(md_ctx, pub->shake256_md, mu, mu_len, w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len)) goto err; @@ -304,61 +412,6 @@ err: return ret; } -/** - * @brief Encode a message - * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5). - * - * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg - * Where ctx is the empty string by default and ctx_len <= 255. - * - * Note this code could be shared with SLH_DSA - * - * @param msg A message to encode - * @param msg_len The size of |msg| - * @param ctx An optional context to add to the message encoding. - * @param ctx_len The size of |ctx|. It must be in the range 0..255 - * @param encode Use the Pure signature encoding if this is 1, and dont encode - * if this value is 0. - * @param tmp A small buffer that may be used if the message is small. - * @param tmp_len The size of |tmp| - * @param out_len The size of the returned encoded buffer. - * @returns A buffer containing the encoded message. If the passed in - * |tmp| buffer is big enough to hold the encoded message then it returns |tmp| - * otherwise it allocates memory which must be freed by the caller. If |encode| - * is 0 then it returns |msg|. NULL is returned if there is a failure. - */ -static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len, - const uint8_t *ctx, size_t ctx_len, int encode, - uint8_t *tmp, size_t tmp_len, size_t *out_len) -{ - uint8_t *encoded = NULL; - size_t encoded_len; - - if (encode == 0) { - /* Raw message */ - *out_len = msg_len; - return (uint8_t *)msg; - } - if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN) - return NULL; - - /* Pure encoding */ - encoded_len = 1 + 1 + ctx_len + msg_len; - *out_len = encoded_len; - if (encoded_len <= tmp_len) { - encoded = tmp; - } else { - encoded = OPENSSL_malloc(encoded_len); - if (encoded == NULL) - return NULL; - } - encoded[0] = 0; - encoded[1] = (uint8_t)ctx_len; - memcpy(&encoded[2], ctx, ctx_len); - memcpy(&encoded[2 + ctx_len], msg, msg_len); - return encoded; -} - /** * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign() * @@ -370,31 +423,43 @@ int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu, const uint8_t *rand, size_t rand_len, int encode, unsigned char *sig, size_t *sig_len, size_t sig_size) { - int ret = 1; - uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL; - size_t m_len = 0; + EVP_MD_CTX *md_ctx = NULL; + uint8_t mu[ML_DSA_MU_BYTES]; + const uint8_t *mu_ptr = mu; + size_t mu_len = sizeof(mu); + int ret = 0; if (ossl_ml_dsa_key_get_priv(priv) == NULL) return 0; - if (sig != NULL) { - if (sig_size < priv->params->sig_len) - return 0; - if (msg_is_mu) { - m = (uint8_t *)msg; - m_len = msg_len; - } else { - m = msg_encode(msg, msg_len, context, context_len, encode, - m_tmp, sizeof(m_tmp), &m_len); - if (m == NULL) - return 0; - if (m != msg && m != m_tmp) - alloced_m = m; - } - ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig); - OPENSSL_free(alloced_m); - } + if (sig_len != NULL) *sig_len = priv->params->sig_len; + + if (sig == NULL) + return (sig_len != NULL) ? 1 : 0; + + if (sig_size < priv->params->sig_len) + return 0; + + if (msg_is_mu) { + mu_ptr = msg; + mu_len = msg_len; + } else { + md_ctx = ml_dsa_mu_init(priv, encode, context, context_len); + if (md_ctx == NULL) + return 0; + + if (!ml_dsa_mu_update(md_ctx, msg, msg_len)) + goto err; + + if (!ml_dsa_mu_finalize(md_ctx, mu, mu_len)) + goto err; + } + + ret = ml_dsa_sign_internal(priv, mu_ptr, mu_len, rand, rand_len, sig); + +err: + EVP_MD_CTX_free(md_ctx); return ret; } @@ -407,27 +472,32 @@ int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu, const uint8_t *context, size_t context_len, int encode, const uint8_t *sig, size_t sig_len) { - uint8_t *m, *alloced_m = NULL; - size_t m_len; - uint8_t m_tmp[1024]; + EVP_MD_CTX *md_ctx = NULL; + uint8_t mu[ML_DSA_MU_BYTES]; + const uint8_t *mu_ptr = mu; + size_t mu_len = sizeof(mu); int ret = 0; if (ossl_ml_dsa_key_get_pub(pub) == NULL) return 0; if (msg_is_mu) { - m = (uint8_t *)msg; - m_len = msg_len; + mu_ptr = msg; + mu_len = msg_len; } else { - m = msg_encode(msg, msg_len, context, context_len, encode, - m_tmp, sizeof(m_tmp), &m_len); - if (m == NULL) + md_ctx = ml_dsa_mu_init(pub, encode, context, context_len); + if (md_ctx == NULL) return 0; - if (m != msg && m != m_tmp) - alloced_m = m; + + if (!ml_dsa_mu_update(md_ctx, msg, msg_len)) + goto err; + + if (!ml_dsa_mu_finalize(md_ctx, mu, mu_len)) + goto err; } - ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len); - OPENSSL_free(alloced_m); + ret = ml_dsa_verify_internal(pub, mu_ptr, mu_len, sig, sig_len); +err: + EVP_MD_CTX_free(md_ctx); return ret; } -- 2.47.2