From: Viktor Dukhovni Date: Sat, 21 Dec 2024 16:07:33 +0000 (+1100) Subject: ML-KEM libcrypto implementation polish X-Git-Tag: openssl-3.5.0-alpha1~548 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=653fc2189dc913404203edb674ce4bb45fe5e177;p=thirdparty%2Fopenssl.git ML-KEM libcrypto implementation polish * Core ML_KEM constants in new * Renamed variant ordinals to ML_KEM__VARIANT, freeing up the unadorned ML_KEM_ names. * Fewer/cleaner macros in * Fewer/cleaner macros for setting up the ML_KEM_VINFO table. * Made (d, z) be separate inputs to the now single key generation function. Both or neither have to be NULL. This supports potential future callers that store them in a different order, or in separate buffers. - Random values are chosen when both are NULL, we never return the generated seeds, rather we may, when/if (d, z) private key support is added, store these in the expanded key, and make them available for import/export. * No need for a stand-by keygen encoded public key buffer when the caller does not provide one (will ask for it later if needed). New `hash_h_pubkey` function can compute the public hash from the expanded form in constant space (384 bytes for 12-bit encoded scalar). * Simplified code in `scalar_mult`. * New `scalar_mult_add` adds the product to an existing scalar. Used in new `matrix_mult_transpose_add` replacing `matrix_mult_transpose`. * Unrolled loop in `encode_12`. * Folded decompression and inverse NTT into vecode_decode, the three were always used together. * Folded inverse NTT into former `matrix_mult` as `matrix_mult_intt`, always used together. * New gencbd_vector_ntt combines CBD vector generation with inverse NTT in one pass. * All this makes for more readable code in `decrypt_cpa` and especially `genkey()`, which no longer requires caller-allocated variant-specific temporary storage (just a single EVP_MD_CTX is still needed). Reviewed-by: Tim Hudson Reviewed-by: Matt Caswell Reviewed-by: Neil Horman (Merged from https://github.com/openssl/openssl/pull/26236) --- diff --git a/crypto/ml_kem/ml_kem.c b/crypto/ml_kem/ml_kem.c index ca75cbe7b84..df30c26d447 100644 --- a/crypto/ml_kem/ml_kem.c +++ b/crypto/ml_kem/ml_kem.c @@ -30,10 +30,17 @@ #define bit0(b) ((b) & 1) #define bitn(n, b) (((b) >> n) & 1) -#define DEGREE ML_KEM_DEGREE -#define BARRETT_SHIFT (2 * ML_KEM_LOG2PRIME) +/* + * 12 bits are sufficient to losslessly represent values in [0, q-1]. + * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT. + */ +#define DEGREE ML_KEM_DEGREE +#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13) +#define LOG2PRIME 12 +#define BARRETT_SHIFT (2 * LOG2PRIME) + #ifdef SHA3_BLOCKSIZE -# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128) +# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128) #endif /* @@ -82,12 +89,12 @@ typedef struct ossl_ml_kem_scalar_st { uint16_t c[ML_KEM_DEGREE]; } scalar; -/* General form of public and private key storage */ +/* Key material allocation layout */ #define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \ - struct ossl_ml_kem_##name##_st { \ + struct name##_alloc { \ /* Public vector |t| */ \ scalar tbuf[(rank)]; \ - /* Pre-computed matrix |m| */ \ + /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \ scalar mbuf[(rank)*(rank)] \ /* optional private key data */ \ private_sz \ @@ -95,10 +102,10 @@ typedef struct ossl_ml_kem_scalar_st { /* Declare variant-specific public and private storage */ #define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \ - DECLARE_ML_KEM_KEYDATA(bits##_puballoc, ML_KEM_##bits##_RANK,;); \ - DECLARE_ML_KEM_KEYDATA(bits##_prvalloc, ML_KEM_##bits##_RANK,;\ + DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \ + DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\ scalar sbuf[ML_KEM_##bits##_RANK]; \ - uint8_t zbuf[ML_KEM_RANDOM_BYTES];) + uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];) DECLARE_ML_KEM_VARIANT_KEYDATA(512); DECLARE_ML_KEM_VARIANT_KEYDATA(768); DECLARE_ML_KEM_VARIANT_KEYDATA(1024); @@ -108,31 +115,20 @@ DECLARE_ML_KEM_VARIANT_KEYDATA(1024); typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1], EVP_MD_CTX *mdctx, const ML_KEM_KEY *key); - -/* The wire form of a losslessly encoded vector (12-bits per element). */ -#define ML_KEM_VECTOR_BYTES(rank) \ - ((3 * ML_KEM_DEGREE / 2) * (rank)) - -/* - * The expanded internal form stores each coefficient as a 16-bit unsigned int - */ -#define ML_KEM_VECALLOC_BYTES(rank) \ - (2 * ML_KEM_DEGREE * (rank)) - -/* - * The wire-form public key consists of the lossless encoding of the vector - * "t" = "A" * "s" + "e", followed by public seed "rho". - */ -#define ML_KEM_PUBKEY_BYTES(rank) \ - (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_RANDOM_BYTES) +static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s); /* - * Our internal serialised private key concatenates serialisations of "s", the - * public key, the public key hash, and the failure secret "z". + * The wire-form of a losslessly encoded vector uses 12-bits per element. + * + * The wire-form public key consists of the lossless encoding of the public + * vector |t|, followed by the public seed |rho|. + * + * Our serialised private key concatenates serialisations of the private vector + * |s|, the public key, the public key hash, and the failure secret |z|. */ -#define ML_KEM_PRVKEY_BYTES(rank) \ - (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_PUBKEY_BYTES(rank) \ - + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES) +#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK) +#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES) +#define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES) /* * Encapsulation produces a vector "u" and a scalar "v", whose coordinates @@ -140,57 +136,9 @@ int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1], * "dv" bits, respectively. This encoding is the ciphertext input for * decapsulation. */ -#define ML_KEM_U_VECTOR_BYTES(rank, du) \ - ((ML_KEM_DEGREE / 8) * (du) * (rank)) -#define ML_KEM_V_SCALAR_BYTES(dv) \ - ((ML_KEM_DEGREE / 8) * (dv)) -#define ML_KEM_CTEXT_BYTES(rank, du, dv) \ - (ML_KEM_U_VECTOR_BYTES(rank, du) + ML_KEM_V_SCALAR_BYTES(dv)) - -/* - * Variant-specific sizes - */ -#define ML_KEM_512_VECTOR_BYTES \ - ML_KEM_VECTOR_BYTES(ML_KEM_512_RANK) -#define ML_KEM_768_VECTOR_BYTES \ - ML_KEM_VECTOR_BYTES(ML_KEM_768_RANK) -#define ML_KEM_1024_VECTOR_BYTES \ - ML_KEM_VECTOR_BYTES(ML_KEM_1024_RANK) - -#define ML_KEM_512_PUBLIC_KEY_BYTES \ - ML_KEM_PUBKEY_BYTES(ML_KEM_512_RANK) -#define ML_KEM_768_PUBLIC_KEY_BYTES \ - ML_KEM_PUBKEY_BYTES(ML_KEM_768_RANK) -#define ML_KEM_1024_PUBLIC_KEY_BYTES \ - ML_KEM_PUBKEY_BYTES(ML_KEM_1024_RANK) - -#define ML_KEM_512_PRIVATE_KEY_BYTES \ - ML_KEM_PRVKEY_BYTES(ML_KEM_512_RANK) -#define ML_KEM_768_PRIVATE_KEY_BYTES \ - ML_KEM_PRVKEY_BYTES(ML_KEM_768_RANK) -#define ML_KEM_1024_PRIVATE_KEY_BYTES \ - ML_KEM_PRVKEY_BYTES(ML_KEM_1024_RANK) - -#define ML_KEM_512_U_VECTOR_BYTES \ - ML_KEM_U_VECTOR_BYTES(ML_KEM_512_RANK, ML_KEM_512_DU) -#define ML_KEM_768_U_VECTOR_BYTES \ - ML_KEM_U_VECTOR_BYTES(ML_KEM_768_RANK, ML_KEM_768_DU) -#define ML_KEM_1024_U_VECTOR_BYTES \ - ML_KEM_U_VECTOR_BYTES(ML_KEM_1024_RANK, ML_KEM_1024_DU) - -#define ML_KEM_512_V_SCALAR_BYTES \ - ML_KEM_V_SCALAR_BYTES(ML_KEM_512_DV) -#define ML_KEM_768_V_SCALAR_BYTES \ - ML_KEM_V_SCALAR_BYTES(ML_KEM_768_DV) -#define ML_KEM_1024_V_SCALAR_BYTES \ - ML_KEM_V_SCALAR_BYTES(ML_KEM_1024_DV) - -#define ML_KEM_512_CIPHERTEXT_BYTES \ - (ML_KEM_512_U_VECTOR_BYTES + ML_KEM_512_V_SCALAR_BYTES) -#define ML_KEM_768_CIPHERTEXT_BYTES \ - (ML_KEM_768_U_VECTOR_BYTES + ML_KEM_768_V_SCALAR_BYTES) -#define ML_KEM_1024_CIPHERTEXT_BYTES \ - (ML_KEM_1024_U_VECTOR_BYTES + ML_KEM_1024_V_SCALAR_BYTES) +#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK) +#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV) +#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b)) /* * Per-variant fixed parameters @@ -198,51 +146,51 @@ int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1], static const ML_KEM_VINFO vinfo_map[3] = { { "ML-KEM-512", - ML_KEM_512_VECTOR_BYTES, - ML_KEM_512_PRIVATE_KEY_BYTES, - ML_KEM_512_PUBLIC_KEY_BYTES, - ML_KEM_512_CIPHERTEXT_BYTES, - ML_KEM_512_U_VECTOR_BYTES, - sizeof(struct ossl_ml_kem_512_puballoc_st), - sizeof(struct ossl_ml_kem_512_prvalloc_st), - ML_KEM_512, - 512, + PRVKEY_BYTES(512), + sizeof(struct prvkey_512_alloc), + PUBKEY_BYTES(512), + sizeof(struct pubkey_512_alloc), + CTEXT_BYTES(512), + VECTOR_BYTES(512), + U_VECTOR_BYTES(512), + ML_KEM_512_VARIANT, + ML_KEM_512_BITS, ML_KEM_512_RANK, ML_KEM_512_DU, ML_KEM_512_DV, - ML_KEM_512_RNGSEC + ML_KEM_512_SECBITS }, { "ML-KEM-768", - ML_KEM_768_VECTOR_BYTES, - ML_KEM_768_PRIVATE_KEY_BYTES, - ML_KEM_768_PUBLIC_KEY_BYTES, - ML_KEM_768_CIPHERTEXT_BYTES, - ML_KEM_768_U_VECTOR_BYTES, - sizeof(struct ossl_ml_kem_768_puballoc_st), - sizeof(struct ossl_ml_kem_768_prvalloc_st), - ML_KEM_768, - 768, + PRVKEY_BYTES(768), + sizeof(struct prvkey_768_alloc), + PUBKEY_BYTES(768), + sizeof(struct pubkey_768_alloc), + CTEXT_BYTES(768), + VECTOR_BYTES(768), + U_VECTOR_BYTES(768), + ML_KEM_768_VARIANT, + ML_KEM_768_BITS, ML_KEM_768_RANK, ML_KEM_768_DU, ML_KEM_768_DV, - ML_KEM_768_RNGSEC + ML_KEM_768_SECBITS }, { "ML-KEM-1024", - ML_KEM_1024_VECTOR_BYTES, - ML_KEM_1024_PRIVATE_KEY_BYTES, - ML_KEM_1024_PUBLIC_KEY_BYTES, - ML_KEM_1024_CIPHERTEXT_BYTES, - ML_KEM_1024_U_VECTOR_BYTES, - sizeof(struct ossl_ml_kem_1024_puballoc_st), - sizeof(struct ossl_ml_kem_1024_prvalloc_st), - ML_KEM_1024, - 1024, + PRVKEY_BYTES(1024), + sizeof(struct prvkey_1024_alloc), + PUBKEY_BYTES(1024), + sizeof(struct pubkey_1024_alloc), + CTEXT_BYTES(1024), + VECTOR_BYTES(1024), + U_VECTOR_BYTES(1024), + ML_KEM_1024_VARIANT, + ML_KEM_1024_BITS, ML_KEM_1024_RANK, ML_KEM_1024_DU, ML_KEM_1024_DV, - ML_KEM_1024_RNGSEC + ML_KEM_1024_SECBITS } }; @@ -255,7 +203,7 @@ static const int kPrime = ML_KEM_PRIME; static const unsigned int kBarrettShift = BARRETT_SHIFT; static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME; static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2; -static const uint16_t kInverseDegree = ML_KEM_INVERSE_DEGREE; +static const uint16_t kInverseDegree = INVERSE_DEGREE; /* * Python helper: @@ -366,6 +314,32 @@ int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len, && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx); } +/* Incremental hash_h of expanded public key */ +static int +hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES], + EVP_MD_CTX *mdctx, ML_KEM_KEY *key) +{ + const ML_KEM_VINFO *vinfo = key->vinfo; + const scalar *t = key->t, *end = t + vinfo->rank; + unsigned int sz; + + if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)) + return 0; + + while (t < end) { + uint8_t buf[3 * DEGREE / 2]; + + scalar_encode_12(buf, t++); + if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf))) + return 0; + } + + if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES)) + return 0; + return EVP_DigestFinal_ex(mdctx, pkhash, &sz) + && ossl_assert(sz == ML_KEM_PKHASH_BYTES); +} + /* * FIPS 203, Section 4.1, equation (4.5): G. SHA3-512 hash of a variable * length input, producing 64 bytes of output, in particular the seeds @@ -579,29 +553,39 @@ static void scalar_sub(scalar *lhs, const scalar *rhs) static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) { - int i; - uint32_t real_real, img_img, real_img, img_real; - - for (i = 0; i < DEGREE / 2; i++) { - real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; - img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1]; - real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; - img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; - out->c[2 * i] = - reduce(real_real + - (uint32_t)reduce(img_img) * kModRoots[i]); - out->c[2 * i + 1] = reduce(img_real + real_img); + uint16_t *curr = out->c, *end = curr + DEGREE; + const uint16_t *lc = lhs->c, *rc = rhs->c; + const uint16_t *roots = kModRoots; + + while (curr < end) { + uint32_t l0 = *lc++, r0 = *rc++; + uint32_t l1 = *lc++, r1 = *rc++; + uint32_t zetapow = *roots++; + + *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow); + *curr++ = reduce(l0 * r1 + l1 * r0); } } +/* Above, but add the result to an existing scalar */ static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs, const scalar *rhs) { - scalar product; + uint16_t *curr = out->c, *end = curr + DEGREE; + const uint16_t *lc = lhs->c, *rc = rhs->c; + const uint16_t *roots = kModRoots; - scalar_mult(&product, lhs, rhs); - scalar_add(out, &product); + while (curr < end) { + uint32_t l0 = *lc++, r0 = *rc++; + uint32_t l1 = *lc++, r1 = *rc++; + uint16_t *c0 = curr++; + uint16_t *c1 = curr++; + uint32_t zetapow = *roots++; + + *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow); + *c1 = reduce(*c1 + l0 * r1 + l1 * r0); + } } static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, @@ -648,17 +632,23 @@ static void scalar_encode(uint8_t *out, const scalar *s, int bits) */ static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s) { - const uint16_t *c = s->c; - int i; + const uint16_t *curr = s->c, *end = curr + DEGREE; + uint16_t c1, c2; - for (i = 0; i < DEGREE / 2; ++i) { - uint16_t c1 = *c++; - uint16_t c2 = *c++; +#define encode_two \ + c1 = *curr++; \ + c2 = *curr++; \ + *out++ = (uint8_t) c1; \ + *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4)); \ + *out++ = (uint8_t) (c2 >> 4) - *out++ = (uint8_t) c1; - *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4)); - *out++ = (uint8_t) (c2 >> 4); + while (curr < end) { + encode_two; + encode_two; + encode_two; + encode_two; } +#undef encode_two } /* @@ -885,20 +875,25 @@ static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank) } /* - * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns one on - * success or zero if any parsed value is >= |ML_KEM_PRIME|. + * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early + * if any parsed value is >= |ML_KEM_PRIME|. The resulting scalars are + * then decompressed and transformed via the NTT. * * Note: Used only in decrypt_cpa(), which returns void and so does not check * the return value of this function. Side-channels are fine when the input * ciphertext to decap() is simply syntactically invalid. */ -static void vector_decode(scalar *out, const uint8_t *in, int bits, int rank) +static void +vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank) { int stride = bits * DEGREE / 8; - for (; rank-- > 0; in += stride) - if (!scalar_decode(out++, in, bits)) + for (; rank-- > 0; in += stride, ++out) { + if (!scalar_decode(out, in, bits)) return; + scalar_decompress(out, bits); + scalar_ntt(out); + } } /* vector_encode(), specialised to bits == 12. */ @@ -930,13 +925,6 @@ static void vector_compress(scalar *a, int bits, int rank) scalar_compress(a++, bits); } -/* In-place decompression of each scalar component */ -static void vector_decompress(scalar *a, int bits, int rank) -{ - while (rank-- > 0) - scalar_decompress(a++, bits); -} - /* The output scalar must not overlap with the inputs */ static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs, int rank) @@ -946,23 +934,12 @@ static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs, scalar_mult_add(out, ++lhs, ++rhs); } -/* In-place NTT transform of a vector */ -static void vector_ntt(scalar *a, int rank) -{ - while (rank-- > 0) - scalar_ntt(a++); -} - -/* In-place inverse NTT transform of a vector */ -static void vector_inverse_ntt(scalar *a, int rank) -{ - while (rank-- > 0) - scalar_inverse_ntt(a++); -} - -/* Here, the output vector must not overlap with the inputs */ +/* + * Here, the output vector must not overlap with the inputs, the result is + * directly subjected to inverse NTT. + */ static void -matrix_mult(scalar *out, const scalar *m, const scalar *a, int rank) +matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank) { const scalar *ar; int i, j; @@ -971,18 +948,19 @@ matrix_mult(scalar *out, const scalar *m, const scalar *a, int rank) scalar_mult(out, m++, ar = a); for (j = rank - 1; j > 0; --j) scalar_mult_add(out, m++, ++ar); + scalar_inverse_ntt(out); } } /* Here, the output vector must not overlap with the inputs */ static void -matrix_mult_transpose(scalar *out, const scalar *m, const scalar *a, int rank) +matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank) { const scalar *mc = m, *mr, *ar; int i, j; - for (i = rank; i-- > 0; ++out) { - scalar_mult(out, mr = mc++, ar = a); + for (i = rank; i-- > 0; ++out) { + scalar_mult_add(out, mr = mc++, ar = a); for (j = rank; --j > 0; ) scalar_mult_add(out, (mr += rank), ++ar); } @@ -1147,8 +1125,28 @@ int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter, return 1; } +/* + * As above plus NTT transform. + */ +static __owur +int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter, + const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank, + EVP_MD_CTX *mdctx, const ML_KEM_KEY *key) +{ + uint8_t input[ML_KEM_RANDOM_BYTES + 1]; + + memcpy(input, seed, ML_KEM_RANDOM_BYTES); + while (rank-- > 0) { + input[ML_KEM_RANDOM_BYTES] = (*counter)++; + if (!cbd(out, input, mdctx, key)) + return 0; + scalar_ntt(out++); + } + return 1; +} + /* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */ -static CBD_FUNC const cbd1[ML_KEM_1024 + 1] = { cbd_3, cbd_2, cbd_2 }; +static CBD_FUNC const cbd1[ML_KEM_1024_VARIANT + 1] = { cbd_3, cbd_2, cbd_2 }; /* * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt. @@ -1186,15 +1184,13 @@ int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES], int dv = vinfo->dv; /* FIPS 203 "y" vector */ - if (!gencbd_vector(y, cbd_1, &counter, r, rank, mdctx, key)) + if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key)) return 0; - vector_ntt(y, rank); /* FIPS 203 "v" scalar */ inner_product(&v, key->t, y, rank); scalar_inverse_ntt(&v); /* FIPS 203 "u" vector */ - matrix_mult(u, key->m, y, rank); - vector_inverse_ntt(u, rank); + matrix_mult_intt(u, key->m, y, rank); /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */ if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key)) @@ -1230,9 +1226,7 @@ decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES], int du = vinfo->du; int dv = vinfo->dv; - vector_decode(u, ctext, du, rank); - vector_decompress(u, du, rank); - vector_ntt(u, rank); + vector_decode_decompress_ntt(u, ctext, du, rank); scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv); scalar_decompress(&v, dv); inner_product(&mask, key->s, u, rank); @@ -1337,14 +1331,13 @@ static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key) * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is * inlined. * - * The caller MUST also pass a pre-allocated scratch buffer |tmp| with room for - * at least one "vector" (rank * sizeof(scalar)), and a pre-allocated digest - * context that is not shared with any concurrent computation. + * The caller MUST pass a pre-allocated digest context that is not shared with + * any concurrent computation. * - * This function outputs the serialised wire-form |ek| public key into the - * provided |pubenc| buffer, and generates the content of the |rho|, |pkhash|, - * |t|, |m|, |s| and |z| components of the private |key| (which must have - * preallocated space for these). + * This function optionally outputs the serialised wire-form |ek| public key + * into the provided |pubenc| buffer, and generates the content of the |rho|, + * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must + * have preallocated space for these). * * Keys are computed from a 32-byte random |d| plus the 1 byte rank for * domain separation. These are concatenated and hashed to produce a pair of @@ -1360,8 +1353,7 @@ static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key) static __owur int genkey(const uint8_t d[ML_KEM_RANDOM_BYTES], const uint8_t z[ML_KEM_RANDOM_BYTES], - scalar *tmp, EVP_MD_CTX *mdctx, - uint8_t *pubenc, ML_KEM_KEY *key) + EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key) { uint8_t hashed[2 * ML_KEM_RANDOM_BYTES]; const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES; @@ -1380,24 +1372,29 @@ int genkey(const uint8_t d[ML_KEM_RANDOM_BYTES], if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key)) return 0; memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES); + /* FIPS 203 |e| vector is initial value of key->t */ if (!matrix_expand(mdctx, key) - || !gencbd_vector(key->s, cbd_1, &counter, sigma, rank, mdctx, key)) - return 0; - vector_ntt(key->s, rank); - /* FIPS 203 |e| vector */ - if (!gencbd_vector(tmp, cbd_1, &counter, sigma, rank, mdctx, key)) + || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key) + || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key)) return 0; - vector_ntt(tmp, rank); - /* Fill in the public key */ - matrix_mult_transpose(key->t, key->m, key->s, rank); - vector_add(key->t, tmp, rank); - encode_pubkey(pubenc, key); - if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key)) - return 0; + /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */ + matrix_mult_transpose_add(key->t, key->m, key->s, rank); + + if (pubenc == NULL) { + /* Incremental digest of public key without in-full serialisation. */ + if (!hash_h_pubkey(key->pkhash, mdctx, key)) + return 0; + } else { + encode_pubkey(pubenc, key); + if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key)) + return 0; + } - /* Save "z" portion of seed for "implicit rejection" on failure */ + /* Save |z| portion of seed for "implicit rejection" on failure. */ memcpy(key->z, z, ML_KEM_RANDOM_BYTES); + /* Also save |d| portion, in suport of likely alternative key format. */ + memcpy(key->d = key->z + ML_KEM_RANDOM_BYTES, d, ML_KEM_RANDOM_BYTES); return 1; } @@ -1496,11 +1493,15 @@ int add_storage(scalar *p, int private, ML_KEM_KEY *key) return 0; /* A public key needs space for |t| and |m| */ key->m = (key->t = p) + rank; - /* A private key also needs space for |s| and |z| */ + /* + * A private key also needs space for |s| and |z|. + * The |z| buffer always includes additional space for |d|, but a key's |d| + * pointer is left NULL when parsed from the NIST format, which omits that + * information. Only keys generated from a (d, z) seed pair will have a + * non-NULL |d| pointer. + */ if (private) key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank)); - else - key->z = (uint8_t *)(key->s = NULL); return 1; } @@ -1514,7 +1515,7 @@ free_storage(ML_KEM_KEY *key) if (key->t == NULL) return; OPENSSL_free(key->t); - key->z = (uint8_t *)(key->s = key->m = key->t = NULL); + key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL); } /* @@ -1527,7 +1528,7 @@ free_storage(ML_KEM_KEY *key) /* Retrieve the parameters of one of the ML-KEM variants */ const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int variant) { - if (variant > ML_KEM_1024) + if (variant > ML_KEM_1024_VARIANT) return NULL; return &vinfo_map[variant]; } @@ -1550,7 +1551,7 @@ ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties, key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties); key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties); key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties); - key->z = (uint8_t *)(key->s = key->m = key->t = NULL); + key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL); if (key->shake128_md != NULL && key->shake256_md != NULL @@ -1587,6 +1588,9 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection) break; case OSSL_KEYMGMT_SELECT_PRIVATE_KEY: ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret); + /* Duplicated keys retain |d|, if available */ + if (key->d != NULL) + ret->d = ret->z + ML_KEM_RANDOM_BYTES; break; } @@ -1616,10 +1620,11 @@ void ossl_ml_kem_key_free(ML_KEM_KEY *key) /*- * Cleanse any sensitive data: * - The private vector |s| is immediately followed by the FO failure - * secret |z|, we can cleanse both in one call. + * secret |z|, and seed |d|, we can cleanse all three in one call. */ if (key->s != NULL) - OPENSSL_cleanse(key->s, key->vinfo->vector_bytes + ML_KEM_RANDOM_BYTES); + OPENSSL_cleanse(key->s, + key->vinfo->vector_bytes + 2 * ML_KEM_RANDOM_BYTES); /* Free the key material */ OPENSSL_free(key->t); @@ -1673,7 +1678,7 @@ int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key) return ret; } -/* Parse input as a new private key */ +/* Parse input as a new private key */ int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len, ML_KEM_KEY *key) { @@ -1700,13 +1705,14 @@ int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len, } /* - * Generate a new keypair from a given seed, giving a deterministic result for - * running tests. The caller can elect to not collect the encoded public key. + * Generate a new keypair either from the input seeds (when non-null), yielding + * a deterministic result for running tests, or securely generated random data. */ -int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seedlen, - uint8_t *pubenc, size_t publen, - ML_KEM_KEY *key) +int ossl_ml_kem_genkey(const uint8_t *d, const uint8_t *z, + uint8_t *pubenc, size_t publen, + ML_KEM_KEY *key) { + uint8_t tmpseed[2 * ML_KEM_RANDOM_BYTES]; EVP_MD_CTX *mdctx = NULL; const ML_KEM_VINFO *vinfo; int ret = 0; @@ -1715,74 +1721,29 @@ int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seedlen, return 0; vinfo = key->vinfo; - if (seed == NULL || seedlen != ML_KEM_SEED_BYTES + /* Both seeds or neither must be NULL */ + if (((d == NULL) ^ (z == NULL)) != 0 || (pubenc != NULL && publen != vinfo->pubkey_bytes) || (mdctx = EVP_MD_CTX_new()) == NULL) return 0; - if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key)) { - const uint8_t *d = seed; - const uint8_t *z = seed + ML_KEM_RANDOM_BYTES; - - /*- - * This avoids the need to handle allocation failures for two (max 2KB - * each) vectors and (if the caller does not want the public key) an - * encoded public key (max 1568 bytes), that are never retained on - * return from this function. - * We stack-allocate these. - */ -# define case_genkey_seed(bits) \ - case ML_KEM_##bits: \ - if (pubenc != NULL) \ - { \ - scalar tmp[ML_KEM_##bits##_RANK]; \ - \ - ret = genkey(d, z, tmp, mdctx, pubenc, key); \ - } else { \ - scalar tmp[ML_KEM_##bits##_RANK]; \ - uint8_t encbuf[ML_KEM_##bits##_PUBLIC_KEY_BYTES]; \ - \ - ret = genkey(d, z, tmp, mdctx, encbuf, key); \ - } \ - break - switch (vinfo->variant) { - case_genkey_seed(512); - case_genkey_seed(768); - case_genkey_seed(1024); - } -# undef case_genkey_seed + if (d == NULL) { + if (RAND_priv_bytes_ex(key->libctx, tmpseed, 2 * ML_KEM_RANDOM_BYTES, + key->vinfo->secbits) <= 0) + return 0; + d = tmpseed; + z = tmpseed + ML_KEM_RANDOM_BYTES; } + if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key)) + ret = genkey(d, z, mdctx, pubenc, key); + if (!ret) free_storage(key); EVP_MD_CTX_free(mdctx); return ret; } -/* - * Generate a new keypair from a random seed, using the library context's - * private DRBG. The caller can elect to not collect the seed or the encoded - * public key. - */ -int ossl_ml_kem_genkey_rand(uint8_t *seed, size_t seedlen, - uint8_t *pubenc, size_t publen, - ML_KEM_KEY *key) -{ - uint8_t tmpseed[ML_KEM_SEED_BYTES]; - uint8_t *sptr = seed == NULL ? tmpseed : seed; - - if (key == NULL - || ossl_ml_kem_have_pubkey(key) - || (seed != NULL && seedlen != sizeof(tmpseed))) - return 0; - - if (RAND_priv_bytes_ex(key->libctx, sptr, sizeof(tmpseed), - key->vinfo->secbits) <= 0) - return 0; - - return ossl_ml_kem_genkey_seed(sptr, sizeof(tmpseed), pubenc, publen, key); -} - /* * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal * This is the deterministic version with randomness supplied externally. @@ -1812,7 +1773,7 @@ int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen, * We stack-allocate these. */ # define case_encap_seed(bits) \ - case ML_KEM_##bits: \ + case ML_KEM_##bits##_VARIANT: \ { \ scalar tmp[2 * ML_KEM_##bits##_RANK]; \ \ @@ -1874,15 +1835,15 @@ int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen, * retained on return from this function. * We stack-allocate these. */ -# define case_decap(bits) \ - case ML_KEM_##bits: \ - { \ - uint8_t cbuf[ML_KEM_##bits##_CIPHERTEXT_BYTES]; \ - scalar tmp[2 * ML_KEM_##bits##_RANK]; \ - \ - ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \ - EVP_MD_CTX_free(mdctx); \ - return ret; \ +# define case_decap(bits) \ + case ML_KEM_##bits##_VARIANT: \ + { \ + uint8_t cbuf[CTEXT_BYTES(bits)]; \ + scalar tmp[2 * ML_KEM_##bits##_RANK]; \ + \ + ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \ + EVP_MD_CTX_free(mdctx); \ + return ret; \ } switch (vinfo->variant) { case_decap(512); diff --git a/include/crypto/ml_kem.h b/include/crypto/ml_kem.h index e4d6dffac46..0604b967633 100644 --- a/include/crypto/ml_kem.h +++ b/include/crypto/ml_kem.h @@ -1,6 +1,6 @@ /* * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. - + * * Licensed under the Apache License 2.0 (the "License"). You may not use * this file except in compliance with the License. You can obtain a copy * in the file LICENSE in the source distribution or at @@ -16,7 +16,6 @@ # include # include # include -# include # define ML_KEM_DEGREE 256 /* @@ -24,13 +23,8 @@ * of unity, the polynomial (X^256+1) splits in Z_q[X] into 128 irreducible * quadratic factors of the form (X^2 - zeta^(2i + 1)). This is used to * implement efficient multiplication in the ring R_q via the "NTT" transform. - * - * 12 bits are sufficient to losslessly represent values in [0, q-1]. - * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT. */ # define ML_KEM_PRIME (ML_KEM_DEGREE * 13 + 1) -# define ML_KEM_LOG2PRIME 12 -# define ML_KEM_INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13) /* * Various ML-KEM primitives need random input, 32-bytes at a time. Key @@ -89,87 +83,51 @@ * - "secbits" is required security strength of the RNG for the random inputs. */ -/* - * The wire form of a losslessly encoded vector (12-bits per element) - */ -# define ML_KEM_VECTOR_BYTES(rank) \ - ((3 * ML_KEM_DEGREE / 2) * (rank)) - -/* - * The wire-form public key consists of the lossless encoding of the vector - * "t" = "A" * "s" + "e", followed by public seed "rho". - */ -# define ML_KEM_PUBKEY_BYTES(rank) \ - (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_RANDOM_BYTES) - -/* - * Our internal serialised private key concatenates serialisations of "s", the - * public key, the public key hash, and the failure secret "z". - */ -# define ML_KEM_PRVKEY_BYTES(rank) \ - (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_PUBKEY_BYTES(rank) \ - + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES) - -/* - * Encapsulation produces a vector "u" and a scalar "v", whose coordinates - * (numbers modulo the ML-KEM prime "q") are lossily encoded using as "du" and - * "dv" bits, respectively. This encoding is the ciphertext input for - * decapsulation. - */ -# define ML_KEM_U_VECTOR_BYTES(rank, du) \ - ((ML_KEM_DEGREE / 8) * (du) * (rank)) -# define ML_KEM_V_SCALAR_BYTES(dv) \ - ((ML_KEM_DEGREE / 8) * (dv)) -# define ML_KEM_CTEXT_BYTES(rank, du, dv) \ - (ML_KEM_U_VECTOR_BYTES(rank, du) + ML_KEM_V_SCALAR_BYTES(dv)) - /* * Variant-specific constants and structures * ----------------------------------------- */ -# define ML_KEM_512_RANK 2 -# define ML_KEM_512_ETA1 3 -# define ML_KEM_512_ETA2 2 -# define ML_KEM_512_DU 10 -# define ML_KEM_512_DV 4 -# define ML_KEM_512_RNGSEC 128 - -# define ML_KEM_768_RANK 3 -# define ML_KEM_768_ETA1 2 -# define ML_KEM_768_ETA2 2 -# define ML_KEM_768_DU 10 -# define ML_KEM_768_DV 4 -# define ML_KEM_768_RNGSEC 192 - -# define ML_KEM_1024_RANK 4 -# define ML_KEM_1024_ETA1 2 -# define ML_KEM_1024_ETA2 2 -# define ML_KEM_1024_DU 11 -# define ML_KEM_1024_DV 5 -# define ML_KEM_1024_RNGSEC 256 +# define ML_KEM_512_VARIANT 0 +# define ML_KEM_512_BITS 512 +# define ML_KEM_512_RANK 2 +# define ML_KEM_512_ETA1 3 +# define ML_KEM_512_ETA2 2 +# define ML_KEM_512_DU 10 +# define ML_KEM_512_DV 4 +# define ML_KEM_512_SECBITS 128 + +# define ML_KEM_768_VARIANT 1 +# define ML_KEM_768_BITS 768 +# define ML_KEM_768_RANK 3 +# define ML_KEM_768_ETA1 2 +# define ML_KEM_768_ETA2 2 +# define ML_KEM_768_DU 10 +# define ML_KEM_768_DV 4 +# define ML_KEM_768_SECBITS 192 + +# define ML_KEM_1024_VARIANT 2 +# define ML_KEM_1024_BITS 1024 +# define ML_KEM_1024_RANK 4 +# define ML_KEM_1024_ETA1 2 +# define ML_KEM_1024_ETA2 2 +# define ML_KEM_1024_DU 11 +# define ML_KEM_1024_DV 5 +# define ML_KEM_1024_SECBITS 256 /* * External variant-specific API * ----------------------------- */ -/* - * Each variant parameter set is associated with an ordinal number which - * represents that parameter set. - */ -# define ML_KEM_512 0 -# define ML_KEM_768 1 -# define ML_KEM_1024 2 - typedef struct { const char *algorithm_name; - size_t vector_bytes; size_t prvkey_bytes; + size_t prvalloc; size_t pubkey_bytes; + size_t puballoc; size_t ctext_bytes; + size_t vector_bytes; size_t u_vector_bytes; - size_t puballoc; - size_t prvalloc; int variant; int bits; int rank; @@ -182,7 +140,7 @@ typedef struct { const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int variant); /* Known as ML_KEM_KEY via crypto/types.h */ -struct ossl_ml_kem_key_st { +typedef struct ossl_ml_kem_key_st { /* Variant metadata, for one of ML-KEM-{512,768,1024} */ const ML_KEM_VINFO *vinfo; @@ -207,11 +165,12 @@ struct ossl_ml_kem_key_st { struct ossl_ml_kem_scalar_st *m; /* Pre-computed pubkey matrix */ struct ossl_ml_kem_scalar_st *s; /* Private key secret vector */ uint8_t *z; /* Private key FO failure secret */ + uint8_t *d; /* Private key seed */ /* Fixed-size/offset built-ins */ uint8_t rho[ML_KEM_RANDOM_BYTES]; /* Matrix recovery seed */ uint8_t pkhash[ML_KEM_PKHASH_BYTES]; /* Hash of wire-form public key */ -}; +} ML_KEM_KEY; /* The public key is always present, when the private is */ # define ossl_ml_kem_key_vinfo(key) ((key)->vinfo) @@ -222,8 +181,6 @@ struct ossl_ml_kem_key_st { * ----- ML-KEM key lifecycle */ -# ifndef OPENSSL_NO_ML_KEM - /* * Allocate a "bare" key for given ML-KEM variant. Initially without any public * or private key material. @@ -257,13 +214,9 @@ __owur int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len, ML_KEM_KEY *key); __owur -int ossl_ml_kem_genkey_rand(uint8_t *seed, size_t seedlen, - uint8_t *pubenc, size_t publen, - ML_KEM_KEY *key); -__owur -int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seed_len, - uint8_t *pubenc, size_t publen, - ML_KEM_KEY *key); +int ossl_ml_kem_genkey(const uint8_t *d, const uint8_t *z, + uint8_t *pubenc, size_t publen, + ML_KEM_KEY *key); /* * Perform an ML-KEM operation with a given ML-KEM key. The key can generally @@ -295,6 +248,4 @@ int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen, __owur int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2); -# endif /* OPENSSL_NO_ML_KEM */ - #endif /* OPENSSL_HEADER_ML_KEM_H */ diff --git a/include/crypto/types.h b/include/crypto/types.h index c5e3d0effbb..ad17f052e45 100644 --- a/include/crypto/types.h +++ b/include/crypto/types.h @@ -29,8 +29,4 @@ typedef struct dsa_st DSA; typedef struct ecx_key_st ECX_KEY; # endif -# ifndef OPENSSL_NO_ML_KEM -typedef struct ossl_ml_kem_key_st ML_KEM_KEY; -# endif - #endif diff --git a/include/openssl/ml_kem.h b/include/openssl/ml_kem.h new file mode 100644 index 00000000000..2b731a534b2 --- /dev/null +++ b/include/openssl/ml_kem.h @@ -0,0 +1,31 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +#ifndef OPENSSL_ML_KEM_H +# define OPENSSL_ML_KEM_H +# pragma once + +# define OSSL_ML_KEM_SHARED_SECRET_BYTES 32 + +# define OSSL_ML_KEM_512_BITS 512 +# define OSSL_ML_KEM_512_SECURITY_BITS 128 +# define OSSL_ML_KEM_512_CIPHERTEXT_BYTES 768 +# define OSSL_ML_KEM_512_PUBLIC_KEY_BYTES 800 + +# define OSSL_ML_KEM_768_BITS 768 +# define OSSL_ML_KEM_768_SECURITY_BITS 192 +# define OSSL_ML_KEM_768_CIPHERTEXT_BYTES 1088 +# define OSSL_ML_KEM_768_PUBLIC_KEY_BYTES 1184 + +# define OSSL_ML_KEM_1024_BITS 1024 +# define OSSL_ML_KEM_1024_SECURITY_BITS 256 +# define OSSL_ML_KEM_1024_CIPHERTEXT_BYTES 1568 +# define OSSL_ML_KEM_1024_PUBLIC_KEY_BYTES 1568 + +#endif diff --git a/providers/common/capabilities.c b/providers/common/capabilities.c index 273093b6845..3c2882601db 100644 --- a/providers/common/capabilities.c +++ b/providers/common/capabilities.c @@ -102,9 +102,9 @@ static const TLS_GROUP_CONSTANTS group_list[] = { { OSSL_TLS_GROUP_ID_ffdhe6144, 128, TLS1_3_VERSION, 0, -1, -1, 0 }, { OSSL_TLS_GROUP_ID_ffdhe8192, 192, TLS1_3_VERSION, 0, -1, -1, 0 }, #ifndef OPENSSL_NO_ML_KEM - { OSSL_TLS_GROUP_ID_mlkem512, ML_KEM_512_RNGSEC, TLS1_3_VERSION, 0, -1, -1, 1 }, - { OSSL_TLS_GROUP_ID_mlkem768, ML_KEM_768_RNGSEC, TLS1_3_VERSION, 0, -1, -1, 1 }, - { OSSL_TLS_GROUP_ID_mlkem1024, ML_KEM_1024_RNGSEC, TLS1_3_VERSION, 0, -1, -1, 1 }, + { OSSL_TLS_GROUP_ID_mlkem512, ML_KEM_512_SECBITS, TLS1_3_VERSION, 0, -1, -1, 1 }, + { OSSL_TLS_GROUP_ID_mlkem768, ML_KEM_768_SECBITS, TLS1_3_VERSION, 0, -1, -1, 1 }, + { OSSL_TLS_GROUP_ID_mlkem1024, ML_KEM_1024_SECBITS, TLS1_3_VERSION, 0, -1, -1, 1 }, #endif }; diff --git a/providers/implementations/keymgmt/ml_kem_kmgmt.c b/providers/implementations/keymgmt/ml_kem_kmgmt.c index 28d7bb1104f..cda7cec9632 100644 --- a/providers/implementations/keymgmt/ml_kem_kmgmt.c +++ b/providers/implementations/keymgmt/ml_kem_kmgmt.c @@ -442,7 +442,8 @@ static void *ml_kem_gen(void *vgctx, OSSL_CALLBACK *osslcb, void *cbarg) ML_KEM_KEY *key; uint8_t *nopub = NULL; uint8_t *seed = gctx->seed; - size_t slen = seed == NULL ? 0 : ML_KEM_SEED_BYTES; + uint8_t *d = seed != NULL ? seed : NULL; + uint8_t *z = seed != NULL ? seed + ML_KEM_RANDOM_BYTES : NULL; int genok = 0; if (gctx == NULL @@ -454,9 +455,7 @@ static void *ml_kem_gen(void *vgctx, OSSL_CALLBACK *osslcb, void *cbarg) if ((gctx->selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0) return key; - genok = seed != NULL - ? ossl_ml_kem_genkey_seed(seed, slen, nopub, 0, key) - : ossl_ml_kem_genkey_rand(NULL, slen, nopub, 0, key); + genok = ossl_ml_kem_genkey(d, z, nopub, 0, key); /* Erase the single-use seed */ if (seed != NULL) @@ -496,12 +495,12 @@ typedef void (*func_ptr_t)(void); static void *ml_kem_##bits##_new(void *provctx) \ { \ return ml_kem_new(provctx == NULL ? NULL : PROV_LIBCTX_OF(provctx), \ - NULL, ML_KEM_##bits); \ + NULL, ML_KEM_##bits##_VARIANT); \ } \ static void *ml_kem_##bits##_gen_init(void *provctx, int selection, \ const OSSL_PARAM params[]) \ { \ - return ml_kem_gen_init(provctx, selection, params, ML_KEM_##bits); \ + return ml_kem_gen_init(provctx, selection, params, ML_KEM_##bits##_VARIANT); \ } \ const OSSL_DISPATCH ossl_ml_kem_##bits##_keymgmt_functions[] = { \ { OSSL_FUNC_KEYMGMT_NEW, (func_ptr_t) ml_kem_##bits##_new }, \ diff --git a/test/ml_kem_evp_extra_test.c b/test/ml_kem_evp_extra_test.c index 349755b3784..1ffeb6fc209 100644 --- a/test/ml_kem_evp_extra_test.c +++ b/test/ml_kem_evp_extra_test.c @@ -230,7 +230,7 @@ static int test_non_derandomised_ml_kem(void) if (!TEST_ptr(sha256 = EVP_MD_fetch(NULL, "sha256", NULL))) return 0; - for (i = ML_KEM_512; i < ML_KEM_1024; ++i) { + for (i = ML_KEM_512_VARIANT; i < ML_KEM_1024_VARIANT; ++i) { const ML_KEM_VINFO *v; OSSL_PARAM params[3], *p; uint8_t hash[32]; diff --git a/test/ml_kem_internal_test.c b/test/ml_kem_internal_test.c index 1505209d3ac..54440313527 100644 --- a/test/ml_kem_internal_test.c +++ b/test/ml_kem_internal_test.c @@ -107,7 +107,7 @@ static int sanity_test(void) decap_entropy = ml_kem_public_entropy + ML_KEM_RANDOM_BYTES; - for (i = ML_KEM_512; i < ML_KEM_1024; ++i) { + for (i = ML_KEM_512_VARIANT; i < ML_KEM_1024_VARIANT; ++i) { OSSL_PARAM params[3]; uint8_t hash[32]; uint8_t shared_secret[ML_KEM_SHARED_SECRET_BYTES]; @@ -145,8 +145,8 @@ static int sanity_test(void) ret2 = -2; /* Generate a private key */ - if (!ossl_ml_kem_genkey_rand(NULL, 0, encoded_public_key, - v->pubkey_bytes, private_key)) + if (!ossl_ml_kem_genkey(NULL, NULL, encoded_public_key, + v->pubkey_bytes, private_key)) goto done; /* Check that no more entropy is available! */