From: Viktor Dukhovni Date: Sun, 5 Jan 2025 12:32:23 +0000 (+1100) Subject: Add ML-DSA-44 and ML-DSA-87, fix endian issues & add fixups X-Git-Tag: openssl-3.5.0-alpha1~605 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a2391f3aa5c701c3b3d0d337d24181c8d55e87e7;p=thirdparty%2Fopenssl.git Add ML-DSA-44 and ML-DSA-87, fix endian issues & add fixups - Make data encoding work on big-endian systems. - Fix some ML-DSA-44 specific bugs related to w1-vector bits per-coefficient, overall size and high-bits rounding. - Use "do { ... } while (pointer < end)" style consistently. - Drop redundant reference counting of provided keys. - Add parameter blocks for ML-DSA-44 and ML-DSA-87 and turn on associated provider glue. These now pass both keygen and siggen tests (to be added separately). Reviewed-by: Tim Hudson Reviewed-by: Matt Caswell (Merged from https://github.com/openssl/openssl/pull/26127) --- diff --git a/crypto/ml_dsa/ml_dsa_encoders.c b/crypto/ml_dsa/ml_dsa_encoders.c index c72091bfbbe..3ad1a04b3e4 100644 --- a/crypto/ml_dsa/ml_dsa_encoders.c +++ b/crypto/ml_dsa/ml_dsa_encoders.c @@ -7,6 +7,7 @@ * https://www.openssl.org/source/license.html */ +#include #include #include "ml_dsa_local.h" #include "ml_dsa_key.h" @@ -14,6 +15,9 @@ #include "ml_dsa_sign.h" #include "internal/packet.h" +/* Cast mod_sub result in support of left-shifts that create 64-bit values. */ +#define mod_sub_64(a, b) ((uint64_t) mod_sub(a, b)) + typedef int (ENCODE_FN)(const POLY *s, WPACKET *pkt); typedef int (DECODE_FN)(POLY *s, PACKET *pkt); @@ -52,12 +56,12 @@ static int poly_encode_4_bits(const POLY *p, WPACKET *pkt) if (!WPACKET_allocate_bytes(pkt, 32 * 4, &out)) return 0; - while (in < end) { + do { uint32_t z0 = *in++; uint32_t z1 = *in++; *out++ = z0 | (z1 << 4); - } + } while (in < end); return 1; } @@ -86,19 +90,19 @@ static int poly_encode_6_bits(const POLY *p, WPACKET *pkt) uint8_t *out; const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS; - if (!WPACKET_allocate_bytes(pkt, 32 * 3, &out)) + if (!WPACKET_allocate_bytes(pkt, 32 * 6, &out)) return 0; - while (in < end) { + do { uint32_t c0 = *in++; uint32_t c1 = *in++; uint32_t c2 = *in++; uint32_t c3 = *in++; *out++ = c0 | (c1 << 6); - *out++ = c1 >> 4 | (c2 << 4); - *out++ = c3; - } + *out++ = (c1 >> 2) | (c2 << 4); + *out++ = (c2 >> 4) | (c3 << 2); + } while (in < end); return 1; } @@ -129,7 +133,7 @@ static int poly_encode_10_bits(const POLY *p, WPACKET *pkt) if (!WPACKET_allocate_bytes(pkt, 32 * 10, &out)) return 0; - while (in < end) { + do { uint32_t c0 = *in++; uint32_t c1 = *in++; uint32_t c2 = *in++; @@ -140,7 +144,7 @@ static int poly_encode_10_bits(const POLY *p, WPACKET *pkt) *out++ = (uint8_t)((c1 >> 6) | (c2 << 4)); *out++ = (uint8_t)((c2 >> 4) | (c3 << 6)); *out++ = (uint8_t)(c3 >> 2); - } + } while (in < end); return 1; } @@ -155,24 +159,23 @@ static int poly_encode_10_bits(const POLY *p, WPACKET *pkt) */ static int poly_decode_10_bits(POLY *p, PACKET *pkt) { - int ret = 0; const uint8_t *in = NULL; - uint32_t v, mask = 0x3ff; /* 10 bits */ + uint32_t v, w, mask = 0x3ff; /* 10 bits */ uint32_t *out = p->coeff, *end = out + ML_DSA_NUM_POLY_COEFFICIENTS; do { if (!PACKET_get_bytes(pkt, &in, 5)) - goto err; - /* put first 4 bytes into v, 5th byte is accessed directly as in[4] */ - memcpy(&v, in, 4); + return 0; + + in = OSSL_CRYPTO_load_u32_le(&v, in); + w = *in; + *out++ = v & mask; *out++ = (v >> 10) & mask; *out++ = (v >> 20) & mask; - *out++ = (v >> 30) | (((uint32_t)in[4]) << 2); + *out++ = (v >> 30) | (w << 2); } while (out < end); - ret = 1; -err: - return ret; + return 1; } /* @@ -199,12 +202,11 @@ static int poly_encode_signed_4(const POLY *p, WPACKET *pkt) if (!WPACKET_allocate_bytes(pkt, 32 * 4, &out)) return 0; - while (in < end) { - uint32_t z0 = mod_sub(4, *in++); /* 0..8 */ - uint32_t z1 = mod_sub(4, *in++); /* 0..8 */ + do { + uint32_t z = mod_sub(4, *in++); - *out++ = z0 | (z1 << 4); - } + *out++ = z | (mod_sub(4, *in++) << 4); + } while (in < end); return 1; } @@ -228,7 +230,7 @@ static int poly_decode_signed_4(POLY *p, PACKET *pkt) for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) { if (!PACKET_get_bytes(pkt, &in, 4)) goto err; - memcpy(&v, in, 4); + in = OSSL_CRYPTO_load_u32_le(&v, in); /* * None of the nibbles may be >= 9. So if the MSB of any nibble is set, @@ -288,20 +290,21 @@ static int poly_encode_signed_2(const POLY *p, WPACKET *pkt) if (!WPACKET_allocate_bytes(pkt, 32 * 3, &out)) return 0; - while (in < end) { - uint32_t z0 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z1 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z2 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z3 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z4 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z5 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z6 = mod_sub(2, *in++); /* 0..7 */ - uint32_t z7 = mod_sub(2, *in++); /* 0..7 */ - - *out++ = (uint8_t)z0 | (uint8_t)(z1 << 3) | (uint8_t)(z2 << 6); - *out++ = (uint8_t)(z2 >> 2) | (uint8_t)(z3 << 1) | (uint8_t)(z4 << 4) | (uint8_t)(z5 << 7); - *out++ = (uint8_t)(z5 >> 1) | (uint8_t)(z6 << 2) | (uint8_t)(z7 << 5); - } + do { + uint32_t z; + + z = mod_sub(2, *in++); + z |= mod_sub(2, *in++) << 3; + z |= mod_sub(2, *in++) << 6; + z |= mod_sub(2, *in++) << 9; + z |= mod_sub(2, *in++) << 12; + z |= mod_sub(2, *in++) << 15; + z |= mod_sub(2, *in++) << 18; + z |= mod_sub(2, *in++) << 21; + + out = OSSL_CRYPTO_store_u16_le(out, (uint16_t) z); + *out++ = (uint8_t) (z >> 16); + } while (in < end); return 1; } @@ -318,14 +321,16 @@ static int poly_encode_signed_2(const POLY *p, WPACKET *pkt) static int poly_decode_signed_2(POLY *p, PACKET *pkt) { int i, ret = 0; - uint32_t v = 0, *out = p->coeff; + uint32_t u = 0, v = 0, *out = p->coeff; uint32_t msbs, mask; const uint8_t *in; for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) { if (!PACKET_get_bytes(pkt, &in, 3)) goto err; - memcpy(&v, in, 3); + memcpy(&u, in, 3); + OSSL_CRYPTO_load_u32_le(&v, (uint8_t *)&u); + /* * Each octal value (3 bits) must be <= 4, So if the MSB is set then the * bottom 2 bits must not be set. @@ -381,22 +386,26 @@ static int poly_encode_signed_two_to_power_12(const POLY *p, WPACKET *pkt) static const uint32_t range = 1u << 12; const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS; - while (in < end) { - uint64_t z0 = mod_sub(range, *in++); /* < 2^13 */ - uint64_t z1 = mod_sub(range, *in++); - uint64_t z2 = mod_sub(range, *in++); - uint64_t z3 = mod_sub(range, *in++); - uint64_t z4 = mod_sub(range, *in++); - uint64_t z5 = mod_sub(range, *in++); - uint64_t z6 = mod_sub(range, *in++); - uint64_t z7 = mod_sub(range, *in++); - uint64_t a1 = (z0) | (z1 << 13) | (z2 << 26) | (z3 << 39) | (z4 << 52); - uint64_t a2 = (z4 >> 12) | (z5 << 1) | (z6 << 14) | (z7 << 27); - - if (!WPACKET_memcpy(pkt, &a1, 8) - || !WPACKET_memcpy(pkt, &a2, 5)) + do { + uint8_t *out; + uint64_t a1, a2; + + if (!WPACKET_allocate_bytes(pkt, 13, &out)) return 0; - } + + a1 = mod_sub_64(range, *in++); + a1 |= mod_sub_64(range, *in++) << 13; + a1 |= mod_sub_64(range, *in++) << 26; + a1 |= mod_sub_64(range, *in++) << 39; + a1 |= (a2 = mod_sub_64(range, *in++)) << 52; + a2 = (a2 >> 12) | (mod_sub_64(range, *in++) << 1); + a2 |= mod_sub_64(range, *in++) << 14; + a2 |= mod_sub_64(range, *in++) << 27; + + out = OSSL_CRYPTO_store_u64_le(out, a1); + out = OSSL_CRYPTO_store_u32_le(out, (uint32_t) a2); + *out = (uint8_t) (a2 >> 32); + } while (in < end); return 1; } @@ -412,17 +421,20 @@ static int poly_encode_signed_two_to_power_12(const POLY *p, WPACKET *pkt) static int poly_decode_signed_two_to_power_12(POLY *p, PACKET *pkt) { int i, ret = 0; - uint64_t a1 = 0, a2 = 0; uint32_t *out = p->coeff; const uint8_t *in; static const uint32_t range = 1u << 12; static const uint32_t mask_13_bits = (1u << 13) - 1; for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) { + uint64_t a1; + uint32_t a2, b13; + if (!PACKET_get_bytes(pkt, &in, 13)) goto err; - memcpy(&a1, in, 8); - memcpy(&a2, in + 8, 5); + in = OSSL_CRYPTO_load_u64_le(&a1, in); + in = OSSL_CRYPTO_load_u32_le(&a2, in); + b13 = (uint32_t) *in; *out++ = mod_sub(range, a1 & mask_13_bits); *out++ = mod_sub(range, (a1 >> 13) & mask_13_bits); @@ -431,7 +443,7 @@ static int poly_decode_signed_two_to_power_12(POLY *p, PACKET *pkt) *out++ = mod_sub(range, (a1 >> 52) | ((a2 << 12) & mask_13_bits)); *out++ = mod_sub(range, (a2 >> 1) & mask_13_bits); *out++ = mod_sub(range, (a2 >> 14) & mask_13_bits); - *out++ = mod_sub(range, (a2 >> 27) & mask_13_bits); + *out++ = mod_sub(range, (a2 >> 27) | (b13 << 5)); } ret = 1; err: @@ -463,22 +475,22 @@ static int poly_encode_signed_two_to_power_19(const POLY *p, WPACKET *pkt) static const uint32_t range = 1u << 19; const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS; - while (in < end) { - uint32_t z0 = mod_sub(range, *in++); /* < 2^20 */ - uint32_t z1 = mod_sub(range, *in++); - uint32_t z2 = mod_sub(range, *in++); - uint32_t z3 = mod_sub(range, *in++); - - z0 |= (z1 << 20); - z1 >>= 12; - z1 |= (z2 << 8) | (z3 << 28); - z3 >>= 4; + do { + uint32_t z0, z1, z2; + uint8_t *out; - if (!WPACKET_memcpy(pkt, &z0, sizeof(z0)) - || !WPACKET_memcpy(pkt, &z1, sizeof(z1)) - || !WPACKET_memcpy(pkt, &z3, 2)) + if (!WPACKET_allocate_bytes(pkt, 10, &out)) return 0; - } + + z0 = mod_sub(range, *in++); + z0 |= (z1 = mod_sub(range, *in++)) << 20; + z1 = (z1 >> 12) | (mod_sub(range, *in++) << 8); + z1 |= (z2 = mod_sub(range, *in++)) << 28; + + out = OSSL_CRYPTO_store_u32_le(out, z0); + out = OSSL_CRYPTO_store_u32_le(out, z1); + out = OSSL_CRYPTO_store_u16_le(out, (uint16_t) (z2 >> 4)); + } while (in < end); return 1; } @@ -494,18 +506,20 @@ static int poly_encode_signed_two_to_power_19(const POLY *p, WPACKET *pkt) static int poly_decode_signed_two_to_power_19(POLY *p, PACKET *pkt) { int i, ret = 0; - uint32_t a1, a2, a3 = 0; uint32_t *out = p->coeff; const uint8_t *in; static const uint32_t range = 1u << 19; static const uint32_t mask_20_bits = (1u << 20) - 1; for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 4); i++) { + uint32_t a1, a2; + uint16_t a3; + if (!PACKET_get_bytes(pkt, &in, 10)) goto err; - memcpy(&a1, in, 4); - memcpy(&a2, in + 4, 4); - memcpy(&a3, in + 8, 2); + in = OSSL_CRYPTO_load_u32_le(&a1, in); + in = OSSL_CRYPTO_load_u32_le(&a2, in); + in = OSSL_CRYPTO_load_u16_le(&a3, in); *out++ = mod_sub(range, a1 & mask_20_bits); *out++ = mod_sub(range, (a1 >> 20) | ((a2 & 0xFF) << 12)); @@ -542,22 +556,22 @@ static int poly_encode_signed_two_to_power_17(const POLY *p, WPACKET *pkt) static const uint32_t range = 1u << 17; const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS; - while (in < end) { - uint32_t z0 = mod_sub(range, *in++); /* < 2^18 */ - uint32_t z1 = mod_sub(range, *in++); - uint32_t z2 = mod_sub(range, *in++); - uint32_t z3 = mod_sub(range, *in++); - - z0 |= (z1 << 18); - z1 >>= 14; - z1 |= (z2 << 4) | (z3 << 22); - z3 >>= 10; + do { + uint8_t *out; + uint32_t z0, z1, z2; - if (!WPACKET_memcpy(pkt, &z0, sizeof(z0)) - || !WPACKET_memcpy(pkt, &z1, sizeof(z1)) - || !WPACKET_memcpy(pkt, &z3, 1)) + if (!WPACKET_allocate_bytes(pkt, 9, &out)) return 0; - } + + z0 = mod_sub(range, *in++); + z0 |= (z1 = mod_sub(range, *in++)) << 18; + z1 = (z1 >> 14) | (mod_sub(range, *in++) << 4); + z1 |= (z2 = mod_sub(range, *in++)) << 22; + + out = OSSL_CRYPTO_store_u32_le(out, z0); + out = OSSL_CRYPTO_store_u32_le(out, z1); + *out = z2 >> 10; + } while (in < end); return 1; } @@ -572,29 +586,27 @@ static int poly_encode_signed_two_to_power_17(const POLY *p, WPACKET *pkt) */ static int poly_decode_signed_two_to_power_17(POLY *p, PACKET *pkt) { - int ret = 0; - uint32_t a1, a2, a3 = 0; uint32_t *out = p->coeff; const uint32_t *end = out + ML_DSA_NUM_POLY_COEFFICIENTS; const uint8_t *in; static const uint32_t range = 1u << 17; static const uint32_t mask_18_bits = (1u << 18) - 1; - while (out < end) { - if (!PACKET_get_bytes(pkt, &in, 10)) - goto err; - memcpy(&a1, in, 4); - memcpy(&a2, in + 4, 4); - memcpy(&a3, in + 8, 1); + do { + uint32_t a1, a2, a3; + + if (!PACKET_get_bytes(pkt, &in, 9)) + return 0; + in = OSSL_CRYPTO_load_u32_le(&a1, in); + in = OSSL_CRYPTO_load_u32_le(&a2, in); + a3 = (uint32_t) *in; *out++ = mod_sub(range, a1 & mask_18_bits); *out++ = mod_sub(range, (a1 >> 18) | ((a2 & 0xF) << 14)); *out++ = mod_sub(range, (a2 >> 4) & mask_18_bits); *out++ = mod_sub(range, (a2 >> 22) | (a3 << 10)); - } - ret = 1; - err: - return ret; + } while (out < end); + return 1; } /* diff --git a/crypto/ml_dsa/ml_dsa_key.c b/crypto/ml_dsa/ml_dsa_key.c index 57eb599507f..55f524df2d0 100644 --- a/crypto/ml_dsa/ml_dsa_key.c +++ b/crypto/ml_dsa/ml_dsa_key.c @@ -37,10 +37,6 @@ ML_DSA_KEY *ossl_ml_dsa_key_new(OSSL_LIB_CTX *libctx, const char *alg) poly_sz = sizeof(POLY) * (params->k * 3 + params->l); ret = OPENSSL_zalloc(sizeof(*ret) + poly_sz); if (ret != NULL) { - if (!CRYPTO_NEW_REF(&ret->references, 1)) { - OPENSSL_free(ret); - return NULL; - } ret->libctx = libctx; ret->params = params; poly = (POLY *)((uint8_t *)ret + sizeof(*ret)); @@ -57,40 +53,15 @@ ML_DSA_KEY *ossl_ml_dsa_key_new(OSSL_LIB_CTX *libctx, const char *alg) */ void ossl_ml_dsa_key_free(ML_DSA_KEY *key) { - int i; - if (key == NULL) return; - CRYPTO_DOWN_REF(&key->references, &i); - REF_PRINT_COUNT("ML_DSA_KEY", key); - if (i > 0) - return; - REF_ASSERT_ISNT(i < 0); - OPENSSL_free(key->pub_encoding); OPENSSL_free(key->priv_encoding); OPENSSL_free(key->propq); - CRYPTO_FREE_REF(&key->references); OPENSSL_free(key); } -/* - * @brief Increase the reference count for a ML_DSA_KEY object. - * @returns 1 on success or 0 otherwise. - */ -int ossl_ml_dsa_key_up_ref(ML_DSA_KEY *key) -{ - int i; - - if (CRYPTO_UP_REF(&key->references, &i) <= 0) - return 0; - - REF_PRINT_COUNT("ML_DSA_KEY", key); - REF_ASSERT_ISNT(i < 2); - return ((i > 1) ? 1 : 0); -} - /** * @brief Are 2 keys equal? * diff --git a/crypto/ml_dsa/ml_dsa_key_compress.c b/crypto/ml_dsa/ml_dsa_key_compress.c index ffeb2b43e45..d109fe0b560 100644 --- a/crypto/ml_dsa/ml_dsa_key_compress.c +++ b/crypto/ml_dsa/ml_dsa_key_compress.c @@ -61,7 +61,7 @@ void ossl_ml_dsa_key_compress_power2_round(uint32_t r, uint32_t *r1, uint32_t *r */ uint32_t ossl_ml_dsa_key_compress_high_bits(uint32_t r, uint32_t gamma2) { - uint32_t r1 = (r + 127) >> 7; + int32_t r1 = (r + 127) >> 7; /* TODO - figure out what this is doing */ if (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV32) { @@ -141,7 +141,7 @@ uint32_t ossl_ml_dsa_key_compress_use_hint(uint32_t hint, uint32_t r, return r0 > 0 ? (r1 + 1) & 15 : (r1 - 1) & 15; } else { /* m = 44 if gamma2 = ((q - 1) / 88) */ - if (r1 > 0) + if (r0 > 0) return (r1 == 43) ? 0 : r1 + 1; else return (r1 == 0) ? 43 : r1 - 1; diff --git a/crypto/ml_dsa/ml_dsa_params.c b/crypto/ml_dsa/ml_dsa_params.c index e1a86943f99..fa56fbb7a30 100644 --- a/crypto/ml_dsa/ml_dsa_params.c +++ b/crypto/ml_dsa/ml_dsa_params.c @@ -11,6 +11,19 @@ #include "ml_dsa_local.h" #include "ml_dsa_params.h" +/* See FIPS 204 Section 4 Table 1 & Table 2 */ +#define ML_DSA_44_TAU 39 +#define ML_DSA_44_LAMBDA 128 +#define ML_DSA_44_K 4 +#define ML_DSA_44_L 4 +#define ML_DSA_44_ETA ML_DSA_ETA_2 +#define ML_DSA_44_BETA 78 +#define ML_DSA_44_OMEGA 80 +#define ML_DSA_44_SECURITY_CATEGORY 2 +#define ML_DSA_44_PRIV_LEN 2560 +#define ML_DSA_44_PUB_LEN 1312 +#define ML_DSA_44_SIG_LEN 2420 + /* See FIPS 204 Section 4 Table 1 & Table 2 */ #define ML_DSA_65_TAU 49 #define ML_DSA_65_LAMBDA 192 @@ -24,7 +37,35 @@ #define ML_DSA_65_PUB_LEN 1952 #define ML_DSA_65_SIG_LEN 3309 +/* See FIPS 204 Section 4 Table 1 & Table 2 */ +#define ML_DSA_87_TAU 60 +#define ML_DSA_87_LAMBDA 256 +#define ML_DSA_87_K 8 +#define ML_DSA_87_L 7 +#define ML_DSA_87_ETA ML_DSA_ETA_2 +#define ML_DSA_87_BETA 120 +#define ML_DSA_87_OMEGA 75 +#define ML_DSA_87_SECURITY_CATEGORY 5 +#define ML_DSA_87_PRIV_LEN 4896 +#define ML_DSA_87_PUB_LEN 2592 +#define ML_DSA_87_SIG_LEN 4627 + static const ML_DSA_PARAMS ml_dsa_params[] = { + { "ML-DSA-44", + ML_DSA_44_TAU, + ML_DSA_44_LAMBDA, + ML_DSA_GAMMA1_TWO_POWER_17, + ML_DSA_GAMMA2_Q_MINUS1_DIV88, + ML_DSA_44_K, + ML_DSA_44_L, + ML_DSA_44_ETA, + ML_DSA_44_BETA, + ML_DSA_44_OMEGA, + ML_DSA_44_SECURITY_CATEGORY, + ML_DSA_44_PRIV_LEN, + ML_DSA_44_PUB_LEN, + ML_DSA_44_SIG_LEN + }, { "ML-DSA-65", ML_DSA_65_TAU, ML_DSA_65_LAMBDA, @@ -40,6 +81,21 @@ static const ML_DSA_PARAMS ml_dsa_params[] = { ML_DSA_65_PUB_LEN, ML_DSA_65_SIG_LEN }, + { "ML-DSA-87", + ML_DSA_87_TAU, + ML_DSA_87_LAMBDA, + ML_DSA_GAMMA1_TWO_POWER_19, + ML_DSA_GAMMA2_Q_MINUS1_DIV32, + ML_DSA_87_K, + ML_DSA_87_L, + ML_DSA_87_ETA, + ML_DSA_87_BETA, + ML_DSA_87_OMEGA, + ML_DSA_87_SECURITY_CATEGORY, + ML_DSA_87_PRIV_LEN, + ML_DSA_87_PUB_LEN, + ML_DSA_87_SIG_LEN + }, {NULL}, }; diff --git a/crypto/ml_dsa/ml_dsa_sign.c b/crypto/ml_dsa/ml_dsa_sign.c index f7abc1bad01..019e2ce02d5 100644 --- a/crypto/ml_dsa/ml_dsa_sign.c +++ b/crypto/ml_dsa/ml_dsa_sign.c @@ -59,16 +59,12 @@ static int ml_dsa_sign_internal(ML_DSA_CTX *ctx, const ML_DSA_KEY *priv, uint32_t k = params->k, l = params->l; uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2; uint8_t *alloc = NULL, *w1_encoded; - size_t w1_encoded_len = 128 * k; + size_t alloc_len, w1_encoded_len; size_t num_polys_sig_k = 2 * k; size_t num_polys_k = 5 * k; size_t num_polys_l = 3 * l; size_t num_polys_k_by_l = k * l; POLY *polys = NULL, *p, *c_ntt; - size_t alloc_len = w1_encoded_len - + sizeof(*polys) - * (1 + num_polys_k + num_polys_l - + num_polys_k_by_l + num_polys_sig_k); VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y; MATRIX a_ntt; ML_DSA_SIG sig; @@ -82,6 +78,10 @@ static int ml_dsa_sign_internal(ML_DSA_CTX *ctx, const ML_DSA_KEY *priv, * Allocate a single blob for most of the variable size temporary variables. * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K). */ + w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128); + alloc_len = w1_encoded_len + + sizeof(*polys) * (1 + num_polys_k + num_polys_l + + num_polys_k_by_l + num_polys_sig_k); alloc = OPENSSL_malloc(alloc_len); if (alloc == NULL) return 0; @@ -139,7 +139,7 @@ static int ml_dsa_sign_internal(ML_DSA_CTX *ctx, const ML_DSA_KEY *priv, vector_high_bits(&w, gamma2, &w1); ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len); - if (!shake_xof_2(h_ctx, mu, sizeof(mu), w1_encoded, 128 * k, + if (!shake_xof_2(h_ctx, mu, sizeof(mu), w1_encoded, w1_encoded_len, c_tilde, c_tilde_len)) break; @@ -203,7 +203,8 @@ static int ml_dsa_verify_internal(ML_DSA_CTX *ctx, const ML_DSA_KEY *pub, const ML_DSA_PARAMS *params = ctx->params; uint32_t k = pub->params->k; uint32_t l = pub->params->l; - size_t w1_encoded_len = 128 * k; + uint32_t gamma2 = params->gamma2; + size_t w1_encoded_len; size_t num_polys_sig = k + l; size_t num_polys_k = 2 * k; size_t num_polys_l = 1 * l; @@ -216,6 +217,7 @@ static int ml_dsa_verify_internal(ML_DSA_CTX *ctx, const ML_DSA_KEY *pub, uint32_t z_max; /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */ + w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128); alloc = OPENSSL_malloc(w1_encoded_len + sizeof(*polys) * (1 + num_polys_k + num_polys_l @@ -261,8 +263,8 @@ static int ml_dsa_verify_internal(ML_DSA_CTX *ctx, const ML_DSA_KEY *pub, /* compute w1_encoded */ w1 = w_approx; - vector_use_hint(&sig.hint, w_approx, params->gamma2, w1); - ossl_ml_dsa_w1_encode(w1, params->gamma2, w1_encoded, w1_encoded_len); + vector_use_hint(&sig.hint, w_approx, gamma2, w1); + ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len); if (!shake_xof_3(h_ctx, mu, sizeof(mu), w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len)) diff --git a/providers/defltprov.c b/providers/defltprov.c index 2b8b8c5c00e..ccc1e8e7e9b 100644 --- a/providers/defltprov.c +++ b/providers/defltprov.c @@ -448,7 +448,9 @@ static const OSSL_ALGORITHM deflt_signature[] = { # endif #endif #ifndef OPENSSL_NO_ML_DSA + { PROV_NAMES_ML_DSA_44, "provider=default", ossl_ml_dsa_44_signature_functions }, { PROV_NAMES_ML_DSA_65, "provider=default", ossl_ml_dsa_65_signature_functions }, + { PROV_NAMES_ML_DSA_87, "provider=default", ossl_ml_dsa_87_signature_functions }, #endif { PROV_NAMES_HMAC, "provider=default", ossl_mac_legacy_hmac_signature_functions }, { PROV_NAMES_SIPHASH, "provider=default", @@ -513,8 +515,12 @@ static const OSSL_ALGORITHM deflt_keymgmt[] = { # endif #endif #ifndef OPENSSL_NO_ML_DSA + { PROV_NAMES_ML_DSA_44, "provider=default", ossl_ml_dsa_44_keymgmt_functions, + PROV_DESCS_ML_DSA_44 }, { PROV_NAMES_ML_DSA_65, "provider=default", ossl_ml_dsa_65_keymgmt_functions, PROV_DESCS_ML_DSA_65 }, + { PROV_NAMES_ML_DSA_87, "provider=default", ossl_ml_dsa_87_keymgmt_functions, + PROV_DESCS_ML_DSA_87 }, #endif /* OPENSSL_NO_ML_DSA */ { PROV_NAMES_TLS1_PRF, "provider=default", ossl_kdf_keymgmt_functions, PROV_DESCS_TLS1_PRF_SIGN }, diff --git a/providers/implementations/include/prov/implementations.h b/providers/implementations/include/prov/implementations.h index f1985805a26..f4b0be387c1 100644 --- a/providers/implementations/include/prov/implementations.h +++ b/providers/implementations/include/prov/implementations.h @@ -321,7 +321,9 @@ extern const OSSL_DISPATCH ossl_cmac_legacy_keymgmt_functions[]; #ifndef OPENSSL_NO_SM2 extern const OSSL_DISPATCH ossl_sm2_keymgmt_functions[]; #endif +extern const OSSL_DISPATCH ossl_ml_dsa_44_keymgmt_functions[]; extern const OSSL_DISPATCH ossl_ml_dsa_65_keymgmt_functions[]; +extern const OSSL_DISPATCH ossl_ml_dsa_87_keymgmt_functions[]; /* Key Exchange */ extern const OSSL_DISPATCH ossl_dh_keyexch_functions[]; @@ -384,7 +386,9 @@ extern const OSSL_DISPATCH ossl_mac_legacy_siphash_signature_functions[]; extern const OSSL_DISPATCH ossl_mac_legacy_poly1305_signature_functions[]; extern const OSSL_DISPATCH ossl_mac_legacy_cmac_signature_functions[]; extern const OSSL_DISPATCH ossl_sm2_signature_functions[]; +extern const OSSL_DISPATCH ossl_ml_dsa_44_signature_functions[]; extern const OSSL_DISPATCH ossl_ml_dsa_65_signature_functions[]; +extern const OSSL_DISPATCH ossl_ml_dsa_87_signature_functions[]; /* Asym Cipher */ extern const OSSL_DISPATCH ossl_rsa_asym_cipher_functions[]; diff --git a/providers/implementations/include/prov/names.h b/providers/implementations/include/prov/names.h index 4728fe903f0..06b9b5ff056 100644 --- a/providers/implementations/include/prov/names.h +++ b/providers/implementations/include/prov/names.h @@ -384,5 +384,9 @@ #define PROV_DESCS_RSA_PSS "OpenSSL RSA-PSS implementation" #define PROV_NAMES_SM2 "SM2:1.2.156.10197.1.301" #define PROV_DESCS_SM2 "OpenSSL SM2 implementation" +#define PROV_NAMES_ML_DSA_44 "ML-DSA-44:2.16.840.1.101.3.4.3.17" +#define PROV_DESCS_ML_DSA_44 "OpenSSL ML-DSA-44 implementation" #define PROV_NAMES_ML_DSA_65 "ML-DSA-65:2.16.840.1.101.3.4.3.18" #define PROV_DESCS_ML_DSA_65 "OpenSSL ML-DSA-65 implementation" +#define PROV_NAMES_ML_DSA_87 "ML-DSA-87:2.16.840.1.101.3.4.3.19" +#define PROV_DESCS_ML_DSA_87 "OpenSSL ML-DSA-87 implementation" diff --git a/providers/implementations/keymgmt/ml_dsa_kmgmt.c b/providers/implementations/keymgmt/ml_dsa_kmgmt.c index 079e67d8ce9..8890c5de4d7 100644 --- a/providers/implementations/keymgmt/ml_dsa_kmgmt.c +++ b/providers/implementations/keymgmt/ml_dsa_kmgmt.c @@ -372,4 +372,6 @@ static void ml_dsa_gen_cleanup(void *genctx) OSSL_DISPATCH_END \ } +MAKE_KEYMGMT_FUNCTIONS("ML-DSA-44", 44); MAKE_KEYMGMT_FUNCTIONS("ML-DSA-65", 65); +MAKE_KEYMGMT_FUNCTIONS("ML-DSA-87", 87); diff --git a/providers/implementations/signature/ml_dsa_sig.c b/providers/implementations/signature/ml_dsa_sig.c index d6eadcbb70d..b17a37f9574 100644 --- a/providers/implementations/signature/ml_dsa_sig.c +++ b/providers/implementations/signature/ml_dsa_sig.c @@ -52,7 +52,6 @@ static void ml_dsa_freectx(void *vctx) OPENSSL_free(ctx->propq); ossl_ml_dsa_ctx_free(ctx->ctx); - ossl_ml_dsa_key_free(ctx->key); OPENSSL_cleanse(ctx->test_entropy, ctx->test_entropy_len); OPENSSL_free(ctx); } @@ -101,9 +100,6 @@ static int ml_dsa_signverify_msg_init(void *vctx, void *vkey, if (key != NULL) { if (!ossl_ml_dsa_key_type_matches(ctx->ctx, key)) return 0; - if (!ossl_ml_dsa_key_up_ref(vkey)) - return 0; - ossl_ml_dsa_key_free(ctx->key); ctx->key = vkey; } @@ -244,4 +240,6 @@ static const OSSL_PARAM *ml_dsa_settable_ctx_params(void *vctx, OSSL_DISPATCH_END \ } +MAKE_SIGNATURE_FUNCTIONS("ML-DSA-44", 44); MAKE_SIGNATURE_FUNCTIONS("ML-DSA-65", 65); +MAKE_SIGNATURE_FUNCTIONS("ML-DSA-87", 87); diff --git a/test/ml_dsa_test.c b/test/ml_dsa_test.c index f6ec12d0498..db71ed843f9 100644 --- a/test/ml_dsa_test.c +++ b/test/ml_dsa_test.c @@ -107,7 +107,7 @@ err: static int ml_dsa_siggen_test(int tst_id) { int ret = 0; - ML_DSA_SIG_TEST_DATA *td = &ml_dsa_siggen_testdata[tst_id]; + const ML_DSA_SIG_TEST_DATA *td = &ml_dsa_siggen_testdata[tst_id]; EVP_PKEY_CTX *sctx = NULL; EVP_PKEY *pkey = NULL; EVP_SIGNATURE *sig_alg = NULL;