From: slontis Date: Mon, 10 Feb 2025 07:06:17 +0000 (+1100) Subject: ML-DSA: Change ossl_ml_dsa_key_public_from_private() to check that the X-Git-Tag: openssl-3.5.0-alpha1~493 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bd8954bfe50b3271475099bcb53565bcb1763e81;p=thirdparty%2Fopenssl.git ML-DSA: Change ossl_ml_dsa_key_public_from_private() to check that the decoded value of t0 matches the calculated value of t0. Reviewed-by: Viktor Dukhovni Reviewed-by: Tim Hudson (Merged from https://github.com/openssl/openssl/pull/26681) --- diff --git a/crypto/ml_dsa/ml_dsa_encoders.c b/crypto/ml_dsa/ml_dsa_encoders.c index 28c2b0c55df..b404ddc6357 100644 --- a/crypto/ml_dsa/ml_dsa_encoders.c +++ b/crypto/ml_dsa/ml_dsa_encoders.c @@ -768,7 +768,7 @@ int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len) PACKET pkt; /* When loading from an explicit key, drop the seed. */ - OPENSSL_free(key->seed); + OPENSSL_clear_free(key->seed, ML_DSA_SEED_BYTES); key->seed = NULL; /* Allow the key encoding to be already set to the provided pointer */ @@ -794,28 +794,28 @@ int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len) for (i = 0; i < l; ++i) if (!decode_fn(key->s1.poly + i, &pkt)) - return 0; + goto err; for (i = 0; i < k; ++i) if (!decode_fn(key->s2.poly + i, &pkt)) - return 0; + goto err; for (i = 0; i < k; ++i) if (!poly_decode_signed_two_to_power_12(key->t0.poly + i, &pkt)) - return 0; + goto err; if (PACKET_remaining(&pkt) != 0) - return 0; + goto err; if (key->priv_encoding == NULL && (key->priv_encoding = OPENSSL_memdup(in, in_len)) == NULL) - return 0; + goto err; /* * Computing the public key also computes its hash, which must be equal to * the |tr| value in the private key, else the key was corrupted. */ - if (ossl_ml_dsa_key_public_from_private(key) != 0 - && memcmp(input_tr, key->tr, sizeof(input_tr)) == 0) - return 1; + if (!ossl_ml_dsa_key_public_from_private(key) + || memcmp(input_tr, key->tr, sizeof(input_tr)) != 0) + goto err; - /* On error, reset the key back to uninitialised. */ - ossl_ml_dsa_key_reset(key); + return 1; + err: return 0; } diff --git a/crypto/ml_dsa/ml_dsa_key.c b/crypto/ml_dsa/ml_dsa_key.c index e0667f8902d..05bf142d979 100644 --- a/crypto/ml_dsa/ml_dsa_key.c +++ b/crypto/ml_dsa/ml_dsa_key.c @@ -355,19 +355,14 @@ int ossl_ml_dsa_key_public_from_private(ML_DSA_KEY *key) if (!vector_alloc(&t0, key->params->k)) /* t0 is already in the private key */ return 0; - if (!ossl_ml_dsa_key_pub_alloc(key)) /* allocate space for t1 */ - return 0; - - md_ctx = EVP_MD_CTX_new(); - if (md_ctx == NULL) - goto err; - - ret = public_from_private(key, md_ctx, &key->t1, &t0) + ret = ((md_ctx = EVP_MD_CTX_new())!= NULL) + && ossl_ml_dsa_key_pub_alloc(key) /* allocate space for t1 */ + && public_from_private(key, md_ctx, &key->t1, &t0) + && vector_equal(&t0, &key->t0) /* compare the generated t0 to the expected */ && ossl_ml_dsa_pk_encode(key) && shake_xof(md_ctx, key->shake256_md, key->pub_encoding, key->params->pk_len, key->tr, sizeof(key->tr)); -err: vector_free(&t0); EVP_MD_CTX_free(md_ctx); return ret; diff --git a/test/ml_dsa_test.c b/test/ml_dsa_test.c index 58bb2ede727..826bb46e39c 100644 --- a/test/ml_dsa_test.c +++ b/test/ml_dsa_test.c @@ -496,6 +496,36 @@ err: return ret; } +static int ml_dsa_priv_pub_bad_t0_test(void) +{ + int ret = 0; + EVP_PKEY *key = NULL; + ML_DSA_SIG_GEN_TEST_DATA *td = &ml_dsa_siggen_testdata[0]; + uint8_t *priv = OPENSSL_memdup(td->priv, td->priv_len); + + if (!TEST_ptr(priv)) + goto err; + memcpy(priv, td->priv, td->priv_len); + /* + * t0 is at the end of the encoding so corrupt it. + * This offset is the start of t0 (which is the last 416 * k bytes)) + */ + priv[td->priv_len - 6 * 416] ^= 1; + if (!TEST_true(ml_dsa_create_keypair(&key, td->alg, + priv, td->priv_len, NULL, 0, 0))) + goto err; + + priv[td->priv_len - 6 * 416] ^= 1; + if (!TEST_true(ml_dsa_create_keypair(&key, td->alg, + priv, td->priv_len, NULL, 0, 1))) + goto err; + ret = 1; + err: + OPENSSL_free(priv); + EVP_PKEY_free(key); + return ret; +} + const OPTIONS *test_get_options(void) { static const OPTIONS options[] = { @@ -539,6 +569,7 @@ int setup_tests(void) ADD_TEST(from_data_invalid_public_test); ADD_TEST(from_data_bad_input_test); ADD_TEST(ml_dsa_digest_sign_verify_test); + ADD_TEST(ml_dsa_priv_pub_bad_t0_test); return 1; }