From: Viktor Dukhovni Date: Mon, 13 Jan 2025 17:34:37 +0000 (+1100) Subject: ML-KEM implementation cleanup/speedup X-Git-Tag: openssl-3.5.0-alpha1~526 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=003309c376a50ab08ee3a4c23c34373f69538210;p=thirdparty%2Fopenssl.git ML-KEM implementation cleanup/speedup Reviewed-by: Matt Caswell Reviewed-by: Tim Hudson Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/26341) --- diff --git a/crypto/ml_kem/ml_kem.c b/crypto/ml_kem/ml_kem.c index a81fe8df5cb..e4713add9bb 100644 --- a/crypto/ml_kem/ml_kem.c +++ b/crypto/ml_kem/ml_kem.c @@ -12,6 +12,7 @@ #include #include #include +#include #if defined(OPENSSL_CONSTANT_TIME_VALIDATION) #include @@ -117,7 +118,7 @@ DECLARE_ML_KEM_VARIANT_KEYDATA(1024); typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1], EVP_MD_CTX *mdctx, const ML_KEM_KEY *key); -static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s); +static void scalar_encode(uint8_t *out, const scalar *s, int bits); /* * The wire-form of a losslessly encoded vector uses 12-bits per element. @@ -166,15 +167,6 @@ static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s); #endif -/* - * Negligible to no performance impact, used only in ciphertext decoding when - * extracting the |u| vector and |v| scalar. - */ -static ossl_inline int constant_time_declassify_int(int v) { - CONSTTIME_DECLASSIFY(&v, sizeof(v)); - return value_barrier(v); -} - /* * Per-variant fixed parameters */ @@ -188,7 +180,7 @@ static const ML_KEM_VINFO vinfo_map[3] = { CTEXT_BYTES(512), VECTOR_BYTES(512), U_VECTOR_BYTES(512), - NID_ML_KEM_512, + EVP_PKEY_ML_KEM_512, ML_KEM_512_BITS, ML_KEM_512_RANK, ML_KEM_512_DU, @@ -204,7 +196,7 @@ static const ML_KEM_VINFO vinfo_map[3] = { CTEXT_BYTES(768), VECTOR_BYTES(768), U_VECTOR_BYTES(768), - NID_ML_KEM_768, + EVP_PKEY_ML_KEM_768, ML_KEM_768_BITS, ML_KEM_768_RANK, ML_KEM_768_DU, @@ -220,7 +212,7 @@ static const ML_KEM_VINFO vinfo_map[3] = { CTEXT_BYTES(1024), VECTOR_BYTES(1024), U_VECTOR_BYTES(1024), - NID_ML_KEM_1024, + EVP_PKEY_ML_KEM_1024, ML_KEM_1024_BITS, ML_KEM_1024_RANK, ML_KEM_1024_DU, @@ -259,32 +251,47 @@ static const uint16_t kInverseDegree = INVERSE_DEGREE; * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */ static const uint16_t kNTTRoots[128] = { - 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, - 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, - 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, - 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, - 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, - 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, - 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, - 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, - 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, - 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1, 1729, 2580, 3289, 2642, 630, 1897, 848, + 1062, 1919, 193, 797, 2786, 3260, 569, 1746, + 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, + 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, + 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, }; -/* InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ +/* + * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] + * Listed in order of use in the inverse NTT loop (index 0 is skipped): + * + * 0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ... + */ static const uint16_t kInverseNTTRoots[128] = { - 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, - 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, - 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, - 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, - 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, - 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, - 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, - 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, - 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, - 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, - 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, + 1, 1175, 2444, 394, 1219, 2300, 1455, 2117, + 1607, 2443, 554, 1179, 2186, 2303, 2926, 2237, + 525, 735, 863, 2768, 1230, 2572, 556, 3010, + 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, + 1745, 2688, 3061, 992, 2596, 941, 892, 1021, + 2390, 642, 1868, 2377, 1482, 1540, 540, 1678, + 1626, 279, 314, 1173, 2573, 3096, 48, 667, + 1920, 2229, 1041, 2606, 1692, 680, 2746, 568, + 3312, 2419, 2102, 219, 855, 2681, 1848, 712, + 682, 927, 1795, 461, 1891, 2877, 2522, 1894, + 1010, 1414, 2009, 3296, 464, 2697, 816, 1352, + 2679, 1274, 1052, 1025, 2132, 1573, 76, 2998, + 3040, 2508, 1355, 450, 936, 447, 2794, 1235, + 1903, 1996, 1089, 3273, 283, 1853, 1990, 882, + 3033, 1583, 2760, 69, 543, 2532, 3136, 1410, + 2267, 2481, 1432, 2699, 687, 40, 749, 1600, }; /* @@ -361,13 +368,13 @@ hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES], if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)) return 0; - while (t < end) { + do { uint8_t buf[3 * DEGREE / 2]; - scalar_encode_12(buf, t++); + scalar_encode(buf, t++, 12); if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf))) return 0; - } + } while (t < end); if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES)) return 0; @@ -421,33 +428,25 @@ int sample_scalar(scalar *out, EVP_MD_CTX *mdctx) uint16_t d; uint8_t b1, b2, b3; -#define sample_scalar_do_bytes \ - b1 = *in++; \ - b2 = *in++; \ - b3 = *in++; \ - \ - if (curr >= endout) \ - break; \ - if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime) \ - *curr++ = d; \ - if (curr >= endout) \ - break; \ - if ((d = (b3 << 4) + (b2 >> 4)) < kPrime) \ - *curr++ = d - - while (curr < endout) { - if (!EVP_DigestSqueeze(mdctx, buf, sizeof(buf))) + do { + if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf))) return 0; - /* Unrolled loop: twelve bytes in, eight 12-bit *candidates* out */ - for (in = buf; in < endin;) { - sample_scalar_do_bytes; - sample_scalar_do_bytes; - sample_scalar_do_bytes; - sample_scalar_do_bytes; - } - } + do { + b1 = *in++; + b2 = *in++; + b3 = *in++; + + if (curr >= endout) + break; + if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime) + *curr++ = d; + if (curr >= endout) + break; + if ((d = (b3 << 4) + (b2 >> 4)) < kPrime) + *curr++ = d; + } while (in < endin); + } while (curr < endout); return 1; -#undef sample_scalar_do_bytes } /*- @@ -480,6 +479,17 @@ static __owur uint16_t reduce(uint32_t x) return reduce_once(remainder); } +/* Multiply a scalar by a constant. */ +static void scalar_mult_const(scalar *s, uint16_t a) +{ + uint16_t *curr = s->c, *end = curr + DEGREE, tmp; + + do { + tmp = reduce(*curr * a); + *curr++ = tmp; + } while (curr < end); +} + /*- * FIPS 203, Section 4.3, Algoritm 9: "NTT". * In-place number theoretic transform of a given scalar. Note that ML-KEM's @@ -491,29 +501,26 @@ static __owur uint16_t reduce(uint32_t x) */ static void scalar_ntt(scalar *s) { - int offset = DEGREE; - int k, step, i, j; - uint32_t step_root; - uint16_t odd, even; + const uint16_t *roots = kNTTRoots; + uint16_t *end = s->c + DEGREE; + int offset = DEGREE / 2; - /* - * `int` is used here because using `size_t` throughout caused a ~5% - * slowdown with Clang 14 on Aarch64. - */ - for (step = 1; step < DEGREE / 2; step <<= 1) { - offset >>= 1; - k = 0; - for (i = 0; i < step; i++) { - step_root = kNTTRoots[i + step]; - for (j = k; j < k + offset; j++) { - odd = reduce(step_root * s->c[j + offset]); - even = s->c[j]; - s->c[j] = reduce_once(odd + even); - s->c[j + offset] = reduce_once(even - odd + kPrime); - } - k += 2 * offset; - } - } + do { + uint16_t *curr = s->c, *peer; + + do { + uint16_t *pause = curr + offset, even, odd; + uint32_t zeta = *++roots; + + peer = pause; + do { + even = *curr; + odd = reduce(*peer * zeta); + *peer++ = reduce_once(even - odd + kPrime); + *curr++ = reduce_once(odd + even); + } while (curr < pause); + } while ((curr = peer) < end); + } while ((offset >>= 1) >= 2); } /*- @@ -523,37 +530,30 @@ static void scalar_ntt(scalar *s) * the number theoretic transform, this leaves off the first step of the normal * iFFT to account for the fact that 3329 does not have a 512th root of unity, * using the precomputed 128 roots of unity stored in InverseNTTRoots. - * - * FIPS 203, Algorithm 10, performs this transformation in a slightly different - * manner, using the same NTTRoots table as the forward NTT transform. */ static void scalar_inverse_ntt(scalar *s) { - int step = DEGREE / 2; - int offset, k, i, j; - uint32_t step_root; - uint16_t odd, even; + const uint16_t *roots = kInverseNTTRoots; + uint16_t *end = s->c + DEGREE; + int offset = 2; - /* - * `int` is used here because using `size_t` throughout caused a ~5% - * slowdown with Clang 14 on Aarch64. - */ - for (offset = 2; offset < DEGREE; offset <<= 1) { - step >>= 1; - k = 0; - for (i = 0; i < step; i++) { - step_root = kInverseNTTRoots[i + step]; - for (j = k; j < k + offset; j++) { - odd = s->c[j + offset]; - even = s->c[j]; - s->c[j] = reduce_once(odd + even); - s->c[j + offset] = reduce(step_root * (even - odd + kPrime)); - } - k += 2 * offset; - } - } - for (i = 0; i < DEGREE; i++) - s->c[i] = reduce(s->c[i] * kInverseDegree); + do { + uint16_t *curr = s->c, *peer; + + do { + uint16_t *pause = curr + offset, even, odd; + uint32_t zeta = *++roots; + + peer = pause; + do { + even = *curr; + odd = *peer; + *peer++ = reduce(zeta * (even - odd + kPrime)); + *curr++ = reduce_once(odd + even); + } while (curr < pause); + } while ((curr = peer) < end); + } while ((offset <<= 1) < DEGREE); + scalar_mult_const(s, kInverseDegree); } /* Addition updating the LHS scalar in-place. */ @@ -592,14 +592,14 @@ static void scalar_mult(scalar *out, const scalar *lhs, const uint16_t *lc = lhs->c, *rc = rhs->c; const uint16_t *roots = kModRoots; - while (curr < end) { + do { uint32_t l0 = *lc++, r0 = *rc++; uint32_t l1 = *lc++, r1 = *rc++; uint32_t zetapow = *roots++; *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow); *curr++ = reduce(l0 * r1 + l1 * r0); - } + } while (curr < end); } /* Above, but add the result to an existing scalar */ @@ -611,7 +611,7 @@ void scalar_mult_add(scalar *out, const scalar *lhs, const uint16_t *lc = lhs->c, *rc = rhs->c; const uint16_t *roots = kModRoots; - while (curr < end) { + do { uint32_t l0 = *lc++, r0 = *rc++; uint32_t l1 = *lc++, r1 = *rc++; uint16_t *c0 = curr++; @@ -620,70 +620,34 @@ void scalar_mult_add(scalar *out, const scalar *lhs, *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow); *c1 = reduce(*c1 + l0 * r1 + l1 * r0); - } + } while (curr < end); } -static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, - 0x1f, 0x3f, 0x7f, 0xff}; - /*- - * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<12. - * Here |bits| is |d|. For efficiency, we handle the d=1, and d=12 cases - * separately. + * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12. + * Here |bits| is |d|. For efficiency, we handle the d=1 case separately. */ static void scalar_encode(uint8_t *out, const scalar *s, int bits) -{ - uint8_t out_byte = 0; - int out_byte_bits = 0; - int i, element_bits_done, chunk_bits, out_bits_remaining; - uint16_t element; - - for (i = 0; i < DEGREE; i++) { - element = s->c[i]; - element_bits_done = 0; - while (element_bits_done < bits) { - chunk_bits = bits - element_bits_done; - out_bits_remaining = 8 - out_byte_bits; - if (chunk_bits >= out_bits_remaining) { - chunk_bits = out_bits_remaining; - out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits; - *out++ = out_byte; - out_byte_bits = 0; - out_byte = 0; - } else { - out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits; - out_byte_bits += chunk_bits; - } - element_bits_done += chunk_bits; - element >>= chunk_bits; - } - } - if (out_byte_bits > 0) - *out = out_byte; -} - -/* - * scalar_encode_12 is |scalar_encode| specialised for |bits| == 12. - */ -static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s) { const uint16_t *curr = s->c, *end = curr + DEGREE; - uint16_t c1, c2; - -#define encode_two \ - c1 = *curr++; \ - c2 = *curr++; \ - *out++ = (uint8_t) c1; \ - *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4)); \ - *out++ = (uint8_t) (c2 >> 4) - - while (curr < end) { - encode_two; - encode_two; - encode_two; - encode_two; - } -#undef encode_two + uint64_t accum = 0, element; + int used = 0; + + do { + element = *curr++; + if (used + bits < 64) { + accum |= element << used; + used += bits; + } else if (used + bits > 64) { + out = OPENSSL_store_u64_le(out, accum | (element << used)); + accum = element >> (64 - used); + used = (used + bits) - 64; + } else { + out = OPENSSL_store_u64_le(out, accum | (element << used)); + accum = 0; + used = 0; + } + } while (curr < end); } /* @@ -709,42 +673,51 @@ static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s) * separately. * * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in - * |out|. It returns one on success and zero if any parsed value is >= - * |kPrime|. - * - * Note: Used in decrypt_cpa(), which returns void and so does not check the - * return value of this function. But also used in vector_decode(), which - * returns early when scalar_decode() fails. + * |out|. */ -static int scalar_decode(scalar *out, const uint8_t *in, int bits) +static void scalar_decode(scalar *out, const uint8_t *in, int bits) { - uint8_t in_byte = 0; - int in_byte_bits_left = 0; - int i, element_bits_done, chunk_bits; - uint16_t element; - - for (i = 0; i < DEGREE; i++) { - element = 0; - element_bits_done = 0; - while (element_bits_done < bits) { - if (in_byte_bits_left == 0) { - in_byte = *in; - in++; - in_byte_bits_left = 8; - } - chunk_bits = bits - element_bits_done; - if (chunk_bits > in_byte_bits_left) - chunk_bits = in_byte_bits_left; - element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done; - in_byte_bits_left -= chunk_bits; - in_byte >>= chunk_bits; - element_bits_done += chunk_bits; + uint16_t *curr = out->c, *end = curr + DEGREE; + uint64_t accum = 0; + int accum_bits = 0, todo = bits; + uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask; + uint16_t element = 0; + + do { + if (accum_bits == 0) { + in = OPENSSL_load_u64_le(&accum, in); + accum_bits = 64; } - if (constant_time_declassify_int(element >= kPrime)) - return 0; - out->c[i] = element; - } - return 1; + if (todo == bits && accum_bits >= bits) { + /* No partial "element", and all the required bits available */ + *curr++ = ((uint16_t) accum) & mask; + accum >>= bits; + accum_bits -= bits; + } else if (accum_bits >= todo) { + /* A partial "element", and all the required bits available */ + *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo)); + accum >>= todo; + accum_bits -= todo; + element = 0; + todo = bits; + mask = bitmask; + } else { + /* + * Only some of the requisite bits accumulated, store |accum_bits| + * of these in |element|. The accumulated bitcount becomes 0, but + * as soon as we have more bits we'll want to merge accum_bits + * fewer of them into the final |element|. + * + * Note that with a 64-bit accumulator and |bits| always 12 or + * less, if we're here, the previous iteration had all the + * requisite bits, and so there are no kept bits in |element|. + */ + element = ((uint16_t) accum) & mask; + todo -= accum_bits; + mask = bitmask >> accum_bits; + accum_bits = 0; + } + } while (curr < end); } static __owur @@ -798,7 +771,7 @@ scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8]) b >>= 1 /* Unrolled to process each byte in one iteration */ - while (curr < end) { + do { b = *in++; decode_decompress_add_bit; decode_decompress_add_bit; @@ -809,7 +782,7 @@ scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8]) decode_decompress_add_bit; decode_decompress_add_bit; decode_decompress_add_bit; - } + } while (curr < end); #undef decode_decompress_add_bit } @@ -893,8 +866,9 @@ static void scalar_decompress(scalar *s, int bits) /* Addition updating the LHS vector in-place. */ static void vector_add(scalar *lhs, const scalar *rhs, int rank) { - while (rank-- > 0) + do { scalar_add(lhs++, rhs++); + } while (--rank > 0); } /* @@ -925,23 +899,12 @@ vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank) int stride = bits * DEGREE / 8; for (; rank-- > 0; in += stride, ++out) { - if (!scalar_decode(out, in, bits)) - return; + scalar_decode(out, in, bits); scalar_decompress(out, bits); scalar_ntt(out); } } -/* vector_encode(), specialised to bits == 12. */ -static void vector_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *a, - int rank) -{ - int stride = 3 * DEGREE / 2; - - for (; rank-- > 0; out += stride) - scalar_encode_12(out, a++); -} - /* vector_decode(), specialised to bits == 12. */ static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank) @@ -957,8 +920,9 @@ int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank) /* In-place compression of each scalar component */ static void vector_compress(scalar *a, int bits, int rank) { - while (rank-- > 0) + do { scalar_compress(a++, bits); + } while (--rank > 0); } /* The output scalar must not overlap with the inputs */ @@ -1051,35 +1015,27 @@ int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1], if (!prf(randbuf, sizeof(randbuf), in, mdctx, key)) return 0; - /* - * Add |kPrime| if |value| underflowed. See |constish_time_non_zero| for a - * discussion on why the value barrier is by default omitted. While this - * could have been written reduce_once(value + kPrime), this is one extra - * addition and small range of |value| tempts some versions of Clang to - * emit a branch. - */ -#define cbd_2_do_byte \ - b = *r++; \ - \ - value = bit0(b) + bitn(1, b); \ - value -= bitn(2, b) + bitn(3, b); \ - mask = constish_time_non_zero(value >> 15); \ - *curr++ = value + (kPrime & mask); \ - \ - value = bitn(4, b) + bitn(5, b); \ - value -= bitn(6, b) + bitn(7, b); \ - mask = constish_time_non_zero(value >> 15); \ - *curr++ = value + (kPrime & mask) - - /* Unrolled, 4 random bytes in, 8 coefficients out */ - while (curr < end) { - cbd_2_do_byte; - cbd_2_do_byte; - cbd_2_do_byte; - cbd_2_do_byte; - } + do { + b = *r++; + + /* + * Add |kPrime| if |value| underflowed. See |constish_time_non_zero| + * for a discussion on why the value barrier is by default omitted. + * While this could have been written reduce_once(value + kPrime), this + * is one extra addition and small range of |value| tempts some + * versions of Clang to emit a branch. + */ + value = bit0(b) + bitn(1, b); + value -= bitn(2, b) + bitn(3, b); + mask = constish_time_non_zero(value >> 15); + *curr++ = value + (kPrime & mask); + + value = bitn(4, b) + bitn(5, b); + value -= bitn(6, b) + bitn(7, b); + mask = constish_time_non_zero(value >> 15); + *curr++ = value + (kPrime & mask); + } while (curr < end); return 1; -#undef cbd_2_do_byte } /* @@ -1100,45 +1056,39 @@ int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1], if (!prf(randbuf, sizeof(randbuf), in, mdctx, key)) return 0; - /* - * Add |kPrime| if |value| underflowed. See |constish_time_non_zero| for a - * discussion on why the value barrier is by default omitted. While this - * could have been written reduce_once(value + kPrime), this is one extra - * addition and small range of |value| tempts some versions of Clang to - * emit a branch. - */ -#define cbd_3_do_bytes \ - b1 = *r++; \ - b2 = *r++; \ - b3 = *r++; \ - \ - value = bit0(b1) + bitn(1, b1) + bitn(2, b1); \ - value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1); \ - mask = constish_time_non_zero(value >> 15); \ - *curr++ = value + (kPrime & mask); \ - \ - value = bitn(6, b1) + bitn(7, b1) + bit0(b2); \ - value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2); \ - mask = constish_time_non_zero(value >> 15); \ - *curr++ = value + (kPrime & mask); \ - \ - value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2); \ - value -= bitn(7, b2) + bit0(b3) + bitn(1, b3); \ - mask = constish_time_non_zero(value >> 15); \ - *curr++ = value + (kPrime & mask); \ - \ - value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3); \ - value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3); \ - mask = constish_time_non_zero(value >> 15); \ + do { + b1 = *r++; + b2 = *r++; + b3 = *r++; + + /* + * Add |kPrime| if |value| underflowed. See |constish_time_non_zero| + * for a discussion on why the value barrier is by default omitted. + * While this could have been written reduce_once(value + kPrime), this + * is one extra addition and small range of |value| tempts some + * versions of Clang to emit a branch. + */ + value = bit0(b1) + bitn(1, b1) + bitn(2, b1); + value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1); + mask = constish_time_non_zero(value >> 15); *curr++ = value + (kPrime & mask); - /* Unrolled, 6 random bytes in, 8 coefficients out */ - while (curr < end) { - cbd_3_do_bytes; - cbd_3_do_bytes; - } + value = bitn(6, b1) + bitn(7, b1) + bit0(b2); + value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2); + mask = constish_time_non_zero(value >> 15); + *curr++ = value + (kPrime & mask); + + value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2); + value -= bitn(7, b2) + bit0(b3) + bitn(1, b3); + mask = constish_time_non_zero(value >> 15); + *curr++ = value + (kPrime & mask); + + value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3); + value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3); + mask = constish_time_non_zero(value >> 15); + *curr++ = value + (kPrime & mask); + } while (curr < end); return 1; -#undef cbd_3_do_bytes } /* @@ -1153,11 +1103,11 @@ int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter, uint8_t input[ML_KEM_RANDOM_BYTES + 1]; memcpy(input, seed, ML_KEM_RANDOM_BYTES); - while (rank-- > 0) { + do { input[ML_KEM_RANDOM_BYTES] = (*counter)++; if (!cbd(out++, input, mdctx, key)) return 0; - } + } while (--rank > 0); return 1; } @@ -1172,17 +1122,17 @@ int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter, uint8_t input[ML_KEM_RANDOM_BYTES + 1]; memcpy(input, seed, ML_KEM_RANDOM_BYTES); - while (rank-- > 0) { + do { input[ML_KEM_RANDOM_BYTES] = (*counter)++; if (!cbd(out, input, mdctx, key)) return 0; scalar_ntt(out++); - } + } while (--rank > 0); return 1; } /* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */ -#define CBD1(evp_type) ((evp_type) == NID_ML_KEM_512 ? cbd_3 : cbd_2) +#define CBD1(evp_type) ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2) /* * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt. @@ -1285,7 +1235,7 @@ static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key) const uint8_t *rho = key->rho; const ML_KEM_VINFO *vinfo = key->vinfo; - vector_encode_12(out, key->t, vinfo->rank); + vector_encode(out, key->t, 12, vinfo->rank); memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES); } @@ -1299,7 +1249,7 @@ static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key) { const ML_KEM_VINFO *vinfo = key->vinfo; - vector_encode_12(out, key->s, vinfo->rank); + vector_encode(out, key->s, 12, vinfo->rank); out += vinfo->vector_bytes; encode_pubkey(out, key); out += vinfo->pubkey_bytes; @@ -1570,11 +1520,11 @@ free_storage(ML_KEM_KEY *key) const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type) { switch (evp_type) { - case NID_ML_KEM_512: + case EVP_PKEY_ML_KEM_512: return &vinfo_map[0]; - case NID_ML_KEM_768: + case EVP_PKEY_ML_KEM_768: return &vinfo_map[1]; - case NID_ML_KEM_1024: + case EVP_PKEY_ML_KEM_1024: return &vinfo_map[2]; } return NULL; @@ -1852,7 +1802,7 @@ int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen, * We stack-allocate these. */ # define case_encap_seed(bits) \ - case NID_ML_KEM_##bits: \ + case EVP_PKEY_ML_KEM_##bits: \ { \ scalar tmp[2 * ML_KEM_##bits##_RANK]; \ \ @@ -1931,7 +1881,7 @@ int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen, * We stack-allocate these. */ # define case_decap(bits) \ - case NID_ML_KEM_##bits: \ + case EVP_PKEY_ML_KEM_##bits: \ { \ uint8_t cbuf[CTEXT_BYTES(bits)]; \ scalar tmp[2 * ML_KEM_##bits##_RANK]; \ diff --git a/providers/implementations/encode_decode/encode_key2any.c b/providers/implementations/encode_decode/encode_key2any.c index 837f14626b5..8f5b2fa86b9 100644 --- a/providers/implementations/encode_decode/encode_key2any.c +++ b/providers/implementations/encode_decode/encode_key2any.c @@ -893,9 +893,15 @@ static int ml_kem_spki_pub_to_der(const void *vkey, unsigned char **pder, publen = key->vinfo->pubkey_bytes; if (pder != NULL - && ((*pder = OPENSSL_malloc(publen)) == NULL - || !ossl_ml_kem_encode_public_key(*pder, publen, key))) + && (*pder = OPENSSL_malloc(publen)) == NULL) return 0; + if (!ossl_ml_kem_encode_public_key(*pder, publen, key)) { + ERR_raise_data(ERR_LIB_OSSL_ENCODER, ERR_R_INTERNAL_ERROR, + "error encoding %s public key", + key->vinfo->algorithm_name); + OPENSSL_free(*pder); + return 0; + } return publen; } diff --git a/providers/implementations/keymgmt/ml_kem_kmgmt.c b/providers/implementations/keymgmt/ml_kem_kmgmt.c index e2782316368..dfb39e3813a 100644 --- a/providers/implementations/keymgmt/ml_kem_kmgmt.c +++ b/providers/implementations/keymgmt/ml_kem_kmgmt.c @@ -640,12 +640,13 @@ static void *ml_kem_dup(const void *vkey, int selection) static void *ml_kem_##bits##_new(void *provctx) \ { \ return ml_kem_new(provctx == NULL ? NULL : PROV_LIBCTX_OF(provctx), \ - NULL, NID_ML_KEM_##bits); \ + NULL, EVP_PKEY_ML_KEM_##bits); \ } \ static void *ml_kem_##bits##_gen_init(void *provctx, int selection, \ const OSSL_PARAM params[]) \ { \ - return ml_kem_gen_init(provctx, selection, params, NID_ML_KEM_##bits); \ + return ml_kem_gen_init(provctx, selection, params, \ + EVP_PKEY_ML_KEM_##bits); \ } \ const OSSL_DISPATCH ossl_ml_kem_##bits##_keymgmt_functions[] = { \ { OSSL_FUNC_KEYMGMT_NEW, (OSSL_FUNC) ml_kem_##bits##_new }, \ diff --git a/providers/implementations/keymgmt/mlx_kmgmt.c b/providers/implementations/keymgmt/mlx_kmgmt.c index efd766bb79b..5b9fb612479 100644 --- a/providers/implementations/keymgmt/mlx_kmgmt.c +++ b/providers/implementations/keymgmt/mlx_kmgmt.c @@ -44,11 +44,11 @@ static const int minimal_selection = OSSL_KEYMGMT_SELECT_DOMAIN_PARAMETERS /* Must match DECLARE_DISPATCH invocations at the end of the file */ static const ECDH_VINFO hybrid_vtable[] = { - { "EC", "P-256", 65, 32, 32, 1, NID_ML_KEM_768 }, - { "EC", "P-384", 97, 48, 48, 1, NID_ML_KEM_1024 }, + { "EC", "P-256", 65, 32, 32, 1, EVP_PKEY_ML_KEM_768 }, + { "EC", "P-384", 97, 48, 48, 1, EVP_PKEY_ML_KEM_1024 }, #if !defined(OPENSSL_NO_ECX) - { "X25519", NULL, 32, 32, 32, 0, NID_ML_KEM_768 }, - { "X448", NULL, 56, 56, 56, 0, NID_ML_KEM_1024 }, + { "X25519", NULL, 32, 32, 32, 0, EVP_PKEY_ML_KEM_768 }, + { "X448", NULL, 56, 56, 56, 0, EVP_PKEY_ML_KEM_1024 }, #endif }; diff --git a/test/ml_kem_evp_extra_test.c b/test/ml_kem_evp_extra_test.c index f55e1cf48e1..f46da9d692a 100644 --- a/test/ml_kem_evp_extra_test.c +++ b/test/ml_kem_evp_extra_test.c @@ -217,7 +217,11 @@ static int test_ml_kem(void) static int test_non_derandomised_ml_kem(void) { - static const int alg[3] = { NID_ML_KEM_512, NID_ML_KEM_768, NID_ML_KEM_1024 }; + static const int alg[3] = { + EVP_PKEY_ML_KEM_512, + EVP_PKEY_ML_KEM_768, + EVP_PKEY_ML_KEM_1024 + }; EVP_RAND_CTX *privctx; EVP_RAND_CTX *pubctx; EVP_MD *sha256; diff --git a/test/ml_kem_internal_test.c b/test/ml_kem_internal_test.c index 0f847a12d40..78b2168f81a 100644 --- a/test/ml_kem_internal_test.c +++ b/test/ml_kem_internal_test.c @@ -92,7 +92,11 @@ static uint8_t ml_kem_expected_shared_secret[3][32] = { static int sanity_test(void) { - static const int alg[3] = { NID_ML_KEM_512, NID_ML_KEM_768, NID_ML_KEM_1024 }; + static const int alg[3] = { + EVP_PKEY_ML_KEM_512, + EVP_PKEY_ML_KEM_768, + EVP_PKEY_ML_KEM_1024 + }; EVP_RAND_CTX *privctx; EVP_RAND_CTX *pubctx; EVP_MD *sha256 = EVP_MD_fetch(NULL, "sha256", NULL);