#include <internal/sha3.h>
#include <crypto/ml_kem.h>
#include <openssl/rand.h>
+#include <openssl/byteorder.h>
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
#include <valgrind/memcheck.h>
typedef __owur
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
-static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s);
+static void scalar_encode(uint8_t *out, const scalar *s, int bits);
/*
* The wire-form of a losslessly encoded vector uses 12 bits per element.
#endif
-/*
- * Negligible to no performance impact, used only in ciphertext decoding when
- * extracting the |u| vector and |v| scalar.
- */
-static ossl_inline int constant_time_declassify_int(int v) {
- CONSTTIME_DECLASSIFY(&v, sizeof(v));
- return value_barrier(v);
-}
-
/*
* Per-variant fixed parameters
*/
CTEXT_BYTES(512),
VECTOR_BYTES(512),
U_VECTOR_BYTES(512),
- NID_ML_KEM_512,
+ EVP_PKEY_ML_KEM_512,
ML_KEM_512_BITS,
ML_KEM_512_RANK,
ML_KEM_512_DU,
CTEXT_BYTES(768),
VECTOR_BYTES(768),
U_VECTOR_BYTES(768),
- NID_ML_KEM_768,
+ EVP_PKEY_ML_KEM_768,
ML_KEM_768_BITS,
ML_KEM_768_RANK,
ML_KEM_768_DU,
CTEXT_BYTES(1024),
VECTOR_BYTES(1024),
U_VECTOR_BYTES(1024),
- NID_ML_KEM_1024,
+ EVP_PKEY_ML_KEM_1024,
ML_KEM_1024_BITS,
ML_KEM_1024_RANK,
ML_KEM_1024_DU,
* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
*/
static const uint16_t kNTTRoots[128] = {
- 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
- 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
- 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
- 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
- 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
- 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
- 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
- 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
- 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
- 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
+ 1, 1729, 2580, 3289, 2642, 630, 1897, 848,
+ 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
+ 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
+ 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
+ 289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
+ 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
+ 2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
+ 2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
+ 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
+ 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
+ 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
+ 939, 2308, 2437, 2388, 733, 2337, 268, 641,
+ 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
+ 1063, 319, 2773, 757, 2099, 561, 2466, 2594,
+ 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
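/*
 * Illustrative sketch, not part of this change: one way the table above could
 * be regenerated from the formula in the preceding comment.  |bitrev7| and
 * |compute_ntt_roots| are hypothetical helpers, not present in this file;
 * |kPrime| (3329) is the ML-KEM modulus defined elsewhere in this file.
 */
static uint16_t bitrev7(uint16_t i)
{
    uint16_t r = 0;
    int bit;

    /* Reverse the seven low-order bits of |i| */
    for (bit = 0; bit < 7; bit++)
        r |= ((i >> bit) & 1) << (6 - bit);
    return r;
}

static void compute_ntt_roots(uint16_t roots[128])
{
    uint32_t acc;
    int i, e;

    for (i = 0; i < 128; i++) {
        /* roots[i] = 17^bitrev7(i) mod kPrime */
        acc = 1;
        for (e = bitrev7((uint16_t) i); e > 0; e--)
            acc = (acc * 17) % kPrime;
        roots[i] = (uint16_t) acc;
    }
}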
-/* InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
+/*
+ * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
+ * The entries are stored in the order in which the inverse NTT loop below
+ * consumes them (entry 0 is never used), i.e. the formula's index i runs:
+ *
+ * 0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
+ */
static const uint16_t kInverseNTTRoots[128] = {
- 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
- 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
- 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
- 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
- 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
- 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
- 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
- 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
- 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
- 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
- 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
+ 1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
+ 1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
+ 525, 735, 863, 2768, 1230, 2572, 556, 3010,
+ 2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
+ 1745, 2688, 3061, 992, 2596, 941, 892, 1021,
+ 2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
+ 1626, 279, 314, 1173, 2573, 3096, 48, 667,
+ 1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
+ 3312, 2419, 2102, 219, 855, 2681, 1848, 712,
+ 682, 927, 1795, 461, 1891, 2877, 2522, 1894,
+ 1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
+ 2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
+ 3040, 2508, 1355, 450, 936, 447, 2794, 1235,
+ 1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
+ 3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
+ 2267, 2481, 1432, 2699, 687, 40, 749, 1600,
};
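/*
 * Illustrative sketch, not part of this change: the index order described in
 * the comment above, expressed as a reordering from a table |src| stored in
 * plain bit-reversed index order (as the previous implementation used) to the
 * order of use |dst| consumed by scalar_inverse_ntt() below.
 * |reorder_inverse_roots| is a hypothetical helper, not present in this file.
 */
static void reorder_inverse_roots(uint16_t dst[128], const uint16_t src[128])
{
    int offset, i, k = 1;

    dst[0] = src[0];                    /* entry 0 is never used */
    for (offset = 2; offset < DEGREE; offset <<= 1) {
        int step = DEGREE / (2 * offset);

        /* The pass with this |offset| consumes src[step] .. src[2 * step - 1] */
        for (i = 0; i < step; i++)
            dst[k++] = src[step + i];
    }
}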
/*
if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
return 0;
- while (t < end) {
+ do {
uint8_t buf[3 * DEGREE / 2];
- scalar_encode_12(buf, t++);
+ scalar_encode(buf, t++, 12);
if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
return 0;
- }
+ } while (t < end);
if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
return 0;
uint16_t d;
uint8_t b1, b2, b3;
-#define sample_scalar_do_bytes \
- b1 = *in++; \
- b2 = *in++; \
- b3 = *in++; \
- \
- if (curr >= endout) \
- break; \
- if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime) \
- *curr++ = d; \
- if (curr >= endout) \
- break; \
- if ((d = (b3 << 4) + (b2 >> 4)) < kPrime) \
- *curr++ = d
-
- while (curr < endout) {
- if (!EVP_DigestSqueeze(mdctx, buf, sizeof(buf)))
+ do {
+ if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
return 0;
- /* Unrolled loop: twelve bytes in, eight 12-bit *candidates* out */
- for (in = buf; in < endin;) {
- sample_scalar_do_bytes;
- sample_scalar_do_bytes;
- sample_scalar_do_bytes;
- sample_scalar_do_bytes;
- }
- }
+ do {
+ b1 = *in++;
+ b2 = *in++;
+ b3 = *in++;
+
+ if (curr >= endout)
+ break;
+ if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
+ *curr++ = d;
+ if (curr >= endout)
+ break;
+ if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
+ *curr++ = d;
+ } while (in < endin);
+ } while (curr < endout);
return 1;
-#undef sample_scalar_do_bytes
}
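/*
 * Illustrative note, not part of this change: each three squeezed bytes
 * (b1, b2, b3) above yield two 12-bit candidates,
 *
 *     d1 = b1 | ((b2 & 0x0f) << 8),    d2 = (b2 >> 4) | (b3 << 4),
 *
 * and a candidate is kept only when it is < kPrime (the rejection sampling of
 * FIPS 203's SampleNTT), so the outer loop may squeeze several buffers before
 * |curr| reaches |endout|.
 */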
/*-
return reduce_once(remainder);
}
+/* Multiply a scalar by a constant. */
+static void scalar_mult_const(scalar *s, uint16_t a)
+{
+ uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
+
+ do {
+ tmp = reduce(*curr * a);
+ *curr++ = tmp;
+ } while (curr < end);
+}
+
/*-
* FIPS 203, Section 4.3, Algorithm 9: "NTT".
* In-place number theoretic transform of a given scalar. Note that ML-KEM's
*/
static void scalar_ntt(scalar *s)
{
- int offset = DEGREE;
- int k, step, i, j;
- uint32_t step_root;
- uint16_t odd, even;
+ const uint16_t *roots = kNTTRoots;
+ uint16_t *end = s->c + DEGREE;
+ int offset = DEGREE / 2;
- /*
- * `int` is used here because using `size_t` throughout caused a ~5%
- * slowdown with Clang 14 on Aarch64.
- */
- for (step = 1; step < DEGREE / 2; step <<= 1) {
- offset >>= 1;
- k = 0;
- for (i = 0; i < step; i++) {
- step_root = kNTTRoots[i + step];
- for (j = k; j < k + offset; j++) {
- odd = reduce(step_root * s->c[j + offset]);
- even = s->c[j];
- s->c[j] = reduce_once(odd + even);
- s->c[j + offset] = reduce_once(even - odd + kPrime);
- }
- k += 2 * offset;
- }
- }
+ do {
+ uint16_t *curr = s->c, *peer;
+
+ do {
+ uint16_t *pause = curr + offset, even, odd;
+ uint32_t zeta = *++roots;
+
+ peer = pause;
+ do {
+ even = *curr;
+ odd = reduce(*peer * zeta);
+ *peer++ = reduce_once(even - odd + kPrime);
+ *curr++ = reduce_once(odd + even);
+ } while (curr < pause);
+ } while ((curr = peer) < end);
+ } while ((offset >>= 1) >= 2);
}
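/*
 * Illustrative note, not part of this change: each inner iteration above is
 * the usual Cooley-Tukey-style butterfly
 *
 *     (even, odd) -> (even + zeta * odd, even - zeta * odd)    (mod kPrime)
 *
 * with |zeta| drawn from |kNTTRoots| in order of use, and |offset| halving
 * each pass from DEGREE/2 down to 2.
 */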
/*-
* the number theoretic transform, this leaves off the first step of the normal
* iFFT to account for the fact that 3329 does not have a 512th root of unity,
* using the precomputed 128 roots of unity stored in InverseNTTRoots.
- *
- * FIPS 203, Algorithm 10, performs this transformation in a slightly different
- * manner, using the same NTTRoots table as the forward NTT transform.
*/
static void scalar_inverse_ntt(scalar *s)
{
- int step = DEGREE / 2;
- int offset, k, i, j;
- uint32_t step_root;
- uint16_t odd, even;
+ const uint16_t *roots = kInverseNTTRoots;
+ uint16_t *end = s->c + DEGREE;
+ int offset = 2;
- /*
- * `int` is used here because using `size_t` throughout caused a ~5%
- * slowdown with Clang 14 on Aarch64.
- */
- for (offset = 2; offset < DEGREE; offset <<= 1) {
- step >>= 1;
- k = 0;
- for (i = 0; i < step; i++) {
- step_root = kInverseNTTRoots[i + step];
- for (j = k; j < k + offset; j++) {
- odd = s->c[j + offset];
- even = s->c[j];
- s->c[j] = reduce_once(odd + even);
- s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
- }
- k += 2 * offset;
- }
- }
- for (i = 0; i < DEGREE; i++)
- s->c[i] = reduce(s->c[i] * kInverseDegree);
+ do {
+ uint16_t *curr = s->c, *peer;
+
+ do {
+ uint16_t *pause = curr + offset, even, odd;
+ uint32_t zeta = *++roots;
+
+ peer = pause;
+ do {
+ even = *curr;
+ odd = *peer;
+ *peer++ = reduce(zeta * (even - odd + kPrime));
+ *curr++ = reduce_once(odd + even);
+ } while (curr < pause);
+ } while ((curr = peer) < end);
+ } while ((offset <<= 1) < DEGREE);
+ scalar_mult_const(s, kInverseDegree);
}
/* Addition updating the LHS scalar in-place. */
const uint16_t *lc = lhs->c, *rc = rhs->c;
const uint16_t *roots = kModRoots;
- while (curr < end) {
+ do {
uint32_t l0 = *lc++, r0 = *rc++;
uint32_t l1 = *lc++, r1 = *rc++;
uint32_t zetapow = *roots++;
*curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
*curr++ = reduce(l0 * r1 + l1 * r0);
- }
+ } while (curr < end);
}
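/*
 * Illustrative note, not part of this change: in the NTT domain each pair of
 * adjacent coefficients represents an element of GF(kPrime)[X]/(X^2 - zeta),
 * so the two stores above compute
 *
 *     (l0 + l1*X) * (r0 + r1*X) = (l0*r0 + l1*r1*zeta) + (l0*r1 + l1*r0)*X
 *
 * with |zetapow| holding the per-pair |zeta| from |kModRoots|.
 */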
/* Above, but add the result to an existing scalar */
const uint16_t *lc = lhs->c, *rc = rhs->c;
const uint16_t *roots = kModRoots;
- while (curr < end) {
+ do {
uint32_t l0 = *lc++, r0 = *rc++;
uint32_t l1 = *lc++, r1 = *rc++;
uint16_t *c0 = curr++;
*c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
*c1 = reduce(*c1 + l0 * r1 + l1 * r0);
- }
+ } while (curr < end);
}
-static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
- 0x1f, 0x3f, 0x7f, 0xff};
-
/*-
- * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<12.
- * Here |bits| is |d|. For efficiency, we handle the d=1, and d=12 cases
- * separately.
+ * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
+ * Here |bits| is |d|. For efficiency, we handle the d=1 case separately.
*/
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
-{
- uint8_t out_byte = 0;
- int out_byte_bits = 0;
- int i, element_bits_done, chunk_bits, out_bits_remaining;
- uint16_t element;
-
- for (i = 0; i < DEGREE; i++) {
- element = s->c[i];
- element_bits_done = 0;
- while (element_bits_done < bits) {
- chunk_bits = bits - element_bits_done;
- out_bits_remaining = 8 - out_byte_bits;
- if (chunk_bits >= out_bits_remaining) {
- chunk_bits = out_bits_remaining;
- out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
- *out++ = out_byte;
- out_byte_bits = 0;
- out_byte = 0;
- } else {
- out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
- out_byte_bits += chunk_bits;
- }
- element_bits_done += chunk_bits;
- element >>= chunk_bits;
- }
- }
- if (out_byte_bits > 0)
- *out = out_byte;
-}
-
-/*
- * scalar_encode_12 is |scalar_encode| specialised for |bits| == 12.
- */
-static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s)
{
const uint16_t *curr = s->c, *end = curr + DEGREE;
- uint16_t c1, c2;
-
-#define encode_two \
- c1 = *curr++; \
- c2 = *curr++; \
- *out++ = (uint8_t) c1; \
- *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4)); \
- *out++ = (uint8_t) (c2 >> 4)
-
- while (curr < end) {
- encode_two;
- encode_two;
- encode_two;
- encode_two;
- }
-#undef encode_two
+ uint64_t accum = 0, element;
+ int used = 0;
+
+ do {
+ element = *curr++;
+ if (used + bits < 64) {
+ accum |= element << used;
+ used += bits;
+ } else if (used + bits > 64) {
+ out = OPENSSL_store_u64_le(out, accum | (element << used));
+ accum = element >> (64 - used);
+ used = (used + bits) - 64;
+ } else {
+ out = OPENSSL_store_u64_le(out, accum | (element << used));
+ accum = 0;
+ used = 0;
+ }
+ } while (curr < end);
}
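/*
 * Illustrative sketch, not part of this change: for d == 12 the little-endian
 * packing above reduces to the familiar "two coefficients into three bytes"
 * layout of the previous scalar_encode_12(), shown here for one coefficient
 * pair.  |encode_pair_12| is a hypothetical helper, not present in this file.
 */
static ossl_inline void encode_pair_12(uint8_t out[3], uint16_t c1, uint16_t c2)
{
    out[0] = (uint8_t) c1;
    out[1] = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4));
    out[2] = (uint8_t) (c2 >> 4);
}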
/*
* separately.
*
* scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
- * |out|. It returns one on success and zero if any parsed value is >=
- * |kPrime|.
- *
- * Note: Used in decrypt_cpa(), which returns void and so does not check the
- * return value of this function. But also used in vector_decode(), which
- * returns early when scalar_decode() fails.
+ * |out|.
*/
-static int scalar_decode(scalar *out, const uint8_t *in, int bits)
+static void scalar_decode(scalar *out, const uint8_t *in, int bits)
{
- uint8_t in_byte = 0;
- int in_byte_bits_left = 0;
- int i, element_bits_done, chunk_bits;
- uint16_t element;
-
- for (i = 0; i < DEGREE; i++) {
- element = 0;
- element_bits_done = 0;
- while (element_bits_done < bits) {
- if (in_byte_bits_left == 0) {
- in_byte = *in;
- in++;
- in_byte_bits_left = 8;
- }
- chunk_bits = bits - element_bits_done;
- if (chunk_bits > in_byte_bits_left)
- chunk_bits = in_byte_bits_left;
- element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
- in_byte_bits_left -= chunk_bits;
- in_byte >>= chunk_bits;
- element_bits_done += chunk_bits;
+ uint16_t *curr = out->c, *end = curr + DEGREE;
+ uint64_t accum = 0;
+ int accum_bits = 0, todo = bits;
+ uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
+ uint16_t element = 0;
+
+ do {
+ if (accum_bits == 0) {
+ in = OPENSSL_load_u64_le(&accum, in);
+ accum_bits = 64;
}
- if (constant_time_declassify_int(element >= kPrime))
- return 0;
- out->c[i] = element;
- }
- return 1;
+ if (todo == bits && accum_bits >= bits) {
+ /* No partial "element", and all the required bits available */
+ *curr++ = ((uint16_t) accum) & mask;
+ accum >>= bits;
+ accum_bits -= bits;
+ } else if (accum_bits >= todo) {
+ /* A partial "element", and all the required bits available */
+ *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
+ accum >>= todo;
+ accum_bits -= todo;
+ element = 0;
+ todo = bits;
+ mask = bitmask;
+ } else {
+ /*
+ * Only some of the requisite bits have accumulated; store the
+ * |accum_bits| bits we have in |element|. The accumulated bit count
+ * drops to zero, and once more bits arrive we will need |accum_bits|
+ * fewer of them to complete this |element|.
+ *
+ * Note that with a 64-bit accumulator and |bits| always 12 or
+ * less, if we're here, the previous iteration had all the
+ * requisite bits, and so there are no kept bits in |element|.
+ */
+ element = ((uint16_t) accum) & mask;
+ todo -= accum_bits;
+ mask = bitmask >> accum_bits;
+ accum_bits = 0;
+ }
+ } while (curr < end);
}
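/*
 * Illustrative sketch, not part of this change: for d == 12 the decoder above
 * is the inverse of the "two coefficients from three bytes" packing, shown
 * here for one byte triple.  |decode_pair_12| is a hypothetical helper, not
 * present in this file.
 */
static ossl_inline void decode_pair_12(uint16_t c[2], const uint8_t in[3])
{
    c[0] = in[0] | (uint16_t) ((in[1] & 0x0f) << 8);
    c[1] = (in[1] >> 4) | (uint16_t) (in[2] << 4);
}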
static __owur
b >>= 1
/* Unrolled to process each byte in one iteration */
- while (curr < end) {
+ do {
b = *in++;
decode_decompress_add_bit;
decode_decompress_add_bit;
decode_decompress_add_bit;
decode_decompress_add_bit;
decode_decompress_add_bit;
- }
+ } while (curr < end);
#undef decode_decompress_add_bit
}
/* Addition updating the LHS vector in-place. */
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
{
- while (rank-- > 0)
+ do {
scalar_add(lhs++, rhs++);
+ } while (--rank > 0);
}
/*
int stride = bits * DEGREE / 8;
for (; rank-- > 0; in += stride, ++out) {
- if (!scalar_decode(out, in, bits))
- return;
+ scalar_decode(out, in, bits);
scalar_decompress(out, bits);
scalar_ntt(out);
}
}
-/* vector_encode(), specialised to bits == 12. */
-static void vector_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *a,
- int rank)
-{
- int stride = 3 * DEGREE / 2;
-
- for (; rank-- > 0; out += stride)
- scalar_encode_12(out, a++);
-}
-
/* vector_decode(), specialised to bits == 12. */
static __owur
int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
/* In-place compression of each scalar component */
static void vector_compress(scalar *a, int bits, int rank)
{
- while (rank-- > 0)
+ do {
scalar_compress(a++, bits);
+ } while (--rank > 0);
}
/* The output scalar must not overlap with the inputs */
if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
return 0;
- /*
- * Add |kPrime| if |value| underflowed. See |constish_time_non_zero| for a
- * discussion on why the value barrier is by default omitted. While this
- * could have been written reduce_once(value + kPrime), this is one extra
- * addition and small range of |value| tempts some versions of Clang to
- * emit a branch.
- */
-#define cbd_2_do_byte \
- b = *r++; \
- \
- value = bit0(b) + bitn(1, b); \
- value -= bitn(2, b) + bitn(3, b); \
- mask = constish_time_non_zero(value >> 15); \
- *curr++ = value + (kPrime & mask); \
- \
- value = bitn(4, b) + bitn(5, b); \
- value -= bitn(6, b) + bitn(7, b); \
- mask = constish_time_non_zero(value >> 15); \
- *curr++ = value + (kPrime & mask)
-
- /* Unrolled, 4 random bytes in, 8 coefficients out */
- while (curr < end) {
- cbd_2_do_byte;
- cbd_2_do_byte;
- cbd_2_do_byte;
- cbd_2_do_byte;
- }
+ do {
+ b = *r++;
+
+ /*
+ * Add |kPrime| if |value| underflowed. See |constish_time_non_zero|
+ * for a discussion of why the value barrier is omitted by default.
+ * This could have been written as reduce_once(value + kPrime), but that
+ * costs an extra addition, and the small range of |value| tempts some
+ * versions of Clang to emit a branch.
+ */
+ value = bit0(b) + bitn(1, b);
+ value -= bitn(2, b) + bitn(3, b);
+ mask = constish_time_non_zero(value >> 15);
+ *curr++ = value + (kPrime & mask);
+
+ value = bitn(4, b) + bitn(5, b);
+ value -= bitn(6, b) + bitn(7, b);
+ mask = constish_time_non_zero(value >> 15);
+ *curr++ = value + (kPrime & mask);
+ } while (curr < end);
return 1;
-#undef cbd_2_do_byte
}
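/*
 * Illustrative note, not part of this change: with eta == 2 each coefficient
 * above is (b0 + b1) - (b2 + b3) for four fresh pseudorandom bits, a value in
 * {-2, ..., 2}; the |kPrime & mask| correction maps the two negative cases to
 * kPrime - 1 and kPrime - 2, keeping every stored coefficient in [0, kPrime).
 * The eta == 3 sampler below works the same way with six bits per coefficient
 * and range {-3, ..., 3}.
 */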
/*
if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
return 0;
- /*
- * Add |kPrime| if |value| underflowed. See |constish_time_non_zero| for a
- * discussion on why the value barrier is by default omitted. While this
- * could have been written reduce_once(value + kPrime), this is one extra
- * addition and small range of |value| tempts some versions of Clang to
- * emit a branch.
- */
-#define cbd_3_do_bytes \
- b1 = *r++; \
- b2 = *r++; \
- b3 = *r++; \
- \
- value = bit0(b1) + bitn(1, b1) + bitn(2, b1); \
- value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1); \
- mask = constish_time_non_zero(value >> 15); \
- *curr++ = value + (kPrime & mask); \
- \
- value = bitn(6, b1) + bitn(7, b1) + bit0(b2); \
- value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2); \
- mask = constish_time_non_zero(value >> 15); \
- *curr++ = value + (kPrime & mask); \
- \
- value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2); \
- value -= bitn(7, b2) + bit0(b3) + bitn(1, b3); \
- mask = constish_time_non_zero(value >> 15); \
- *curr++ = value + (kPrime & mask); \
- \
- value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3); \
- value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3); \
- mask = constish_time_non_zero(value >> 15); \
+ do {
+ b1 = *r++;
+ b2 = *r++;
+ b3 = *r++;
+
+ /*
+ * Add |kPrime| if |value| underflowed. See |constish_time_non_zero|
+ * for a discussion of why the value barrier is omitted by default.
+ * This could have been written as reduce_once(value + kPrime), but that
+ * costs an extra addition, and the small range of |value| tempts some
+ * versions of Clang to emit a branch.
+ */
+ value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
+ value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
+ mask = constish_time_non_zero(value >> 15);
*curr++ = value + (kPrime & mask);
- /* Unrolled, 6 random bytes in, 8 coefficients out */
- while (curr < end) {
- cbd_3_do_bytes;
- cbd_3_do_bytes;
- }
+ value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
+ value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
+ mask = constish_time_non_zero(value >> 15);
+ *curr++ = value + (kPrime & mask);
+
+ value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
+ value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
+ mask = constish_time_non_zero(value >> 15);
+ *curr++ = value + (kPrime & mask);
+
+ value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
+ value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
+ mask = constish_time_non_zero(value >> 15);
+ *curr++ = value + (kPrime & mask);
+ } while (curr < end);
return 1;
-#undef cbd_3_do_bytes
}
/*
uint8_t input[ML_KEM_RANDOM_BYTES + 1];
memcpy(input, seed, ML_KEM_RANDOM_BYTES);
- while (rank-- > 0) {
+ do {
input[ML_KEM_RANDOM_BYTES] = (*counter)++;
if (!cbd(out++, input, mdctx, key))
return 0;
- }
+ } while (--rank > 0);
return 1;
}
uint8_t input[ML_KEM_RANDOM_BYTES + 1];
memcpy(input, seed, ML_KEM_RANDOM_BYTES);
- while (rank-- > 0) {
+ do {
input[ML_KEM_RANDOM_BYTES] = (*counter)++;
if (!cbd(out, input, mdctx, key))
return 0;
scalar_ntt(out++);
- }
+ } while (--rank > 0);
return 1;
}
/* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
-#define CBD1(evp_type) ((evp_type) == NID_ML_KEM_512 ? cbd_3 : cbd_2)
+#define CBD1(evp_type) ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
/*
* FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
const uint8_t *rho = key->rho;
const ML_KEM_VINFO *vinfo = key->vinfo;
- vector_encode_12(out, key->t, vinfo->rank);
+ vector_encode(out, key->t, 12, vinfo->rank);
memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
}
{
const ML_KEM_VINFO *vinfo = key->vinfo;
- vector_encode_12(out, key->s, vinfo->rank);
+ vector_encode(out, key->s, 12, vinfo->rank);
out += vinfo->vector_bytes;
encode_pubkey(out, key);
out += vinfo->pubkey_bytes;
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
{
switch (evp_type) {
- case NID_ML_KEM_512:
+ case EVP_PKEY_ML_KEM_512:
return &vinfo_map[0];
- case NID_ML_KEM_768:
+ case EVP_PKEY_ML_KEM_768:
return &vinfo_map[1];
- case NID_ML_KEM_1024:
+ case EVP_PKEY_ML_KEM_1024:
return &vinfo_map[2];
}
return NULL;
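/*
 * Illustrative usage sketch, not part of this change: per-variant parameters
 * are now looked up by EVP_PKEY type rather than by NID, e.g.
 *
 *     const ML_KEM_VINFO *v = ossl_ml_kem_get_vinfo(EVP_PKEY_ML_KEM_768);
 *
 *     if (v == NULL)
 *         return 0;
 *     ... use v->rank, v->vector_bytes, v->pubkey_bytes, ...
 */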
* We stack-allocate these.
*/
# define case_encap_seed(bits) \
- case NID_ML_KEM_##bits: \
+ case EVP_PKEY_ML_KEM_##bits: \
{ \
scalar tmp[2 * ML_KEM_##bits##_RANK]; \
\
* We stack-allocate these.
*/
# define case_decap(bits) \
- case NID_ML_KEM_##bits: \
+ case EVP_PKEY_ML_KEM_##bits: \
{ \
uint8_t cbuf[CTEXT_BYTES(bits)]; \
scalar tmp[2 * ML_KEM_##bits##_RANK]; \