]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
ML-KEM implementation cleanup/speedup
authorViktor Dukhovni <openssl-users@dukhovni.org>
Mon, 13 Jan 2025 17:34:37 +0000 (04:34 +1100)
committerTomas Mraz <tomas@openssl.org>
Fri, 14 Feb 2025 09:50:58 +0000 (10:50 +0100)
Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/26341)

crypto/ml_kem/ml_kem.c
providers/implementations/encode_decode/encode_key2any.c
providers/implementations/keymgmt/ml_kem_kmgmt.c
providers/implementations/keymgmt/mlx_kmgmt.c
test/ml_kem_evp_extra_test.c
test/ml_kem_internal_test.c

index a81fe8df5cb7ad964f3a13e65dc416c3cb0858b3..e4713add9bb569044e40467cb295f8cecb89f858 100644 (file)
@@ -12,6 +12,7 @@
 #include <internal/sha3.h>
 #include <crypto/ml_kem.h>
 #include <openssl/rand.h>
+#include <openssl/byteorder.h>
 
 #if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
 #include <valgrind/memcheck.h>
@@ -117,7 +118,7 @@ DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
 typedef __owur
 int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
                 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
-static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s);
+static void scalar_encode(uint8_t *out, const scalar *s, int bits);
 
 /*
  * The wire-form of a losslessly encoded vector uses 12-bits per element.
@@ -166,15 +167,6 @@ static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s);
 
 #endif
 
-/*
- * Negligible to no performance impact, used only in ciphertext decoding when
- * extracting the |u| vector and |v| scalar.
- */
-static ossl_inline int constant_time_declassify_int(int v) {
-  CONSTTIME_DECLASSIFY(&v, sizeof(v));
-  return value_barrier(v);
-}
-
 /*
  * Per-variant fixed parameters
  */
@@ -188,7 +180,7 @@ static const ML_KEM_VINFO vinfo_map[3] = {
         CTEXT_BYTES(512),
         VECTOR_BYTES(512),
         U_VECTOR_BYTES(512),
-        NID_ML_KEM_512,
+        EVP_PKEY_ML_KEM_512,
         ML_KEM_512_BITS,
         ML_KEM_512_RANK,
         ML_KEM_512_DU,
@@ -204,7 +196,7 @@ static const ML_KEM_VINFO vinfo_map[3] = {
         CTEXT_BYTES(768),
         VECTOR_BYTES(768),
         U_VECTOR_BYTES(768),
-        NID_ML_KEM_768,
+        EVP_PKEY_ML_KEM_768,
         ML_KEM_768_BITS,
         ML_KEM_768_RANK,
         ML_KEM_768_DU,
@@ -220,7 +212,7 @@ static const ML_KEM_VINFO vinfo_map[3] = {
         CTEXT_BYTES(1024),
         VECTOR_BYTES(1024),
         U_VECTOR_BYTES(1024),
-        NID_ML_KEM_1024,
+        EVP_PKEY_ML_KEM_1024,
         ML_KEM_1024_BITS,
         ML_KEM_1024_RANK,
         ML_KEM_1024_DU,
@@ -259,32 +251,47 @@ static const uint16_t kInverseDegree = INVERSE_DEGREE;
  * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
  */
 static const uint16_t kNTTRoots[128] = {
-    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
-    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
-    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
-    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
-    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
-    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
-    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
-    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
-    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
-    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
+    1,    1729, 2580, 3289, 2642, 630,  1897, 848,
+    1062, 1919, 193,  797,  2786, 3260, 569,  1746,
+    296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
+    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
+    289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
+    650,  1977, 2513, 632,  2865, 33,   1320, 1915,
+    2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
+    2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
+    17,   2761, 583,  2649, 1637, 723,  2288, 1100,
+    1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
+    1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
+    939,  2308, 2437, 2388, 733,  2337, 268,  641,
+    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
+    1063, 319,  2773, 757,  2099, 561,  2466, 2594,
+    2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
     1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
 };
 
-/* InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
+/*
+ * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
+ * Listed in order of use in the inverse NTT loop (index 0 is skipped):
+ *
+ *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
+ */
 static const uint16_t kInverseNTTRoots[128] = {
-    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
-    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
-    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
-    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
-    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
-    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
-    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
-    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
-    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
-    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
-    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
+    1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
+    1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
+    525,  735,  863,  2768, 1230, 2572, 556,  3010,
+    2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
+    1745, 2688, 3061, 992,  2596, 941,  892,  1021,
+    2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
+    1626, 279,  314,  1173, 2573, 3096, 48,   667,
+    1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
+    3312, 2419, 2102, 219,  855,  2681, 1848, 712,
+    682,  927,  1795, 461,  1891, 2877, 2522, 1894,
+    1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
+    2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
+    3040, 2508, 1355, 450,  936,  447,  2794, 1235,
+    1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
+    3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
+    2267, 2481, 1432, 2699, 687,  40,   749,  1600,
 };
 
 /*
@@ -361,13 +368,13 @@ hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
     if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
         return 0;
 
-    while (t < end) {
+    do {
         uint8_t buf[3 * DEGREE / 2];
 
-        scalar_encode_12(buf, t++);
+        scalar_encode(buf, t++, 12);
         if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
             return 0;
-    }
+    } while (t < end);
 
     if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
         return 0;
@@ -421,33 +428,25 @@ int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
     uint16_t d;
     uint8_t b1, b2, b3;
 
-#define sample_scalar_do_bytes                              \
-        b1 = *in++;                                         \
-        b2 = *in++;                                         \
-        b3 = *in++;                                         \
-                                                            \
-        if (curr >= endout)                                 \
-            break;                                          \
-        if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)         \
-            *curr++ = d;                                    \
-        if (curr >= endout)                                 \
-            break;                                          \
-        if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)           \
-            *curr++ = d
-
-    while (curr < endout) {
-        if (!EVP_DigestSqueeze(mdctx, buf, sizeof(buf)))
+    do {
+        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
             return 0;
-        /* Unrolled loop: twelve bytes in, eight 12-bit *candidates* out */
-        for (in = buf; in < endin;) {
-            sample_scalar_do_bytes;
-            sample_scalar_do_bytes;
-            sample_scalar_do_bytes;
-            sample_scalar_do_bytes;
-        }
-    }
+        do {
+            b1 = *in++;
+            b2 = *in++;
+            b3 = *in++;
+
+            if (curr >= endout)
+                break;
+            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
+                *curr++ = d;
+            if (curr >= endout)
+                break;
+            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
+                *curr++ = d;
+        } while (in < endin);
+    } while (curr < endout);
     return 1;
-#undef sample_scalar_do_bytes
 }
 
 /*-
@@ -480,6 +479,17 @@ static __owur uint16_t reduce(uint32_t x)
     return reduce_once(remainder);
 }
 
+/* Multiply a scalar by a constant. */
+static void scalar_mult_const(scalar *s, uint16_t a)
+{
+    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
+
+    do {
+        tmp = reduce(*curr * a);
+        *curr++ = tmp;
+    } while (curr < end);
+}
+
 /*-
  * FIPS 203, Section 4.3, Algoritm 9: "NTT".
  * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
@@ -491,29 +501,26 @@ static __owur uint16_t reduce(uint32_t x)
  */
 static void scalar_ntt(scalar *s)
 {
-    int offset = DEGREE;
-    int k, step, i, j;
-    uint32_t step_root;
-    uint16_t odd, even;
+    const uint16_t *roots = kNTTRoots;
+    uint16_t *end = s->c + DEGREE;
+    int offset = DEGREE / 2;
 
-    /*
-     * `int` is used here because using `size_t` throughout caused a ~5%
-     * slowdown with Clang 14 on Aarch64.
-     */
-    for (step = 1; step < DEGREE / 2; step <<= 1) {
-        offset >>= 1;
-        k = 0;
-        for (i = 0; i < step; i++) {
-            step_root = kNTTRoots[i + step];
-            for (j = k; j < k + offset; j++) {
-                odd = reduce(step_root * s->c[j + offset]);
-                even = s->c[j];
-                s->c[j] = reduce_once(odd + even);
-                s->c[j + offset] = reduce_once(even - odd + kPrime);
-            }
-            k += 2 * offset;
-        }
-    }
+    do {
+        uint16_t *curr = s->c, *peer;
+
+        do {
+            uint16_t *pause = curr + offset, even, odd;
+            uint32_t zeta = *++roots;
+
+            peer = pause;
+            do {
+                even = *curr;
+                odd = reduce(*peer * zeta);
+                *peer++ = reduce_once(even - odd + kPrime);
+                *curr++ = reduce_once(odd + even);
+            } while (curr < pause);
+        } while ((curr = peer) < end);
+    } while ((offset >>= 1) >= 2);
 }
 
 /*-
@@ -523,37 +530,30 @@ static void scalar_ntt(scalar *s)
  * the number theoretic transform, this leaves off the first step of the normal
  * iFFT to account for the fact that 3329 does not have a 512th root of unity,
  * using the precomputed 128 roots of unity stored in InverseNTTRoots.
- *
- * FIPS 203, Algorithm 10, performs this transformation in a slightly different
- * manner, using the same NTTRoots table as the forward NTT transform.
  */
 static void scalar_inverse_ntt(scalar *s)
 {
-    int step = DEGREE / 2;
-    int offset, k, i, j;
-    uint32_t step_root;
-    uint16_t odd, even;
+    const uint16_t *roots = kInverseNTTRoots;
+    uint16_t *end = s->c + DEGREE;
+    int offset = 2;
 
-    /*
-     * `int` is used here because using `size_t` throughout caused a ~5%
-     * slowdown with Clang 14 on Aarch64.
-     */
-    for (offset = 2; offset < DEGREE; offset <<= 1) {
-        step >>= 1;
-        k = 0;
-        for (i = 0; i < step; i++) {
-            step_root = kInverseNTTRoots[i + step];
-            for (j = k; j < k + offset; j++) {
-                odd = s->c[j + offset];
-                even = s->c[j];
-                s->c[j] = reduce_once(odd + even);
-                s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
-            }
-            k += 2 * offset;
-        }
-    }
-    for (i = 0; i < DEGREE; i++)
-        s->c[i] = reduce(s->c[i] * kInverseDegree);
+    do {
+        uint16_t *curr = s->c, *peer;
+
+        do {
+            uint16_t *pause = curr + offset, even, odd;
+            uint32_t zeta = *++roots;
+
+            peer = pause;
+            do {
+                even = *curr;
+                odd = *peer;
+                *peer++ = reduce(zeta * (even - odd + kPrime));
+                *curr++ = reduce_once(odd + even);
+            } while (curr < pause);
+        } while ((curr = peer) < end);
+    } while ((offset <<= 1) < DEGREE);
+    scalar_mult_const(s, kInverseDegree);
 }
 
 /* Addition updating the LHS scalar in-place. */
@@ -592,14 +592,14 @@ static void scalar_mult(scalar *out, const scalar *lhs,
     const uint16_t *lc = lhs->c, *rc = rhs->c;
     const uint16_t *roots = kModRoots;
 
-    while (curr < end) {
+    do {
         uint32_t l0 = *lc++, r0 = *rc++;
         uint32_t l1 = *lc++, r1 = *rc++;
         uint32_t zetapow = *roots++;
 
         *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
         *curr++ = reduce(l0 * r1 + l1 * r0);
-    }
+    } while (curr < end);
 }
 
 /* Above, but add the result to an existing scalar */
@@ -611,7 +611,7 @@ void scalar_mult_add(scalar *out, const scalar *lhs,
     const uint16_t *lc = lhs->c, *rc = rhs->c;
     const uint16_t *roots = kModRoots;
 
-    while (curr < end) {
+    do {
         uint32_t l0 = *lc++, r0 = *rc++;
         uint32_t l1 = *lc++, r1 = *rc++;
         uint16_t *c0 = curr++;
@@ -620,70 +620,34 @@ void scalar_mult_add(scalar *out, const scalar *lhs,
 
         *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
         *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
-    }
+    } while (curr < end);
 }
 
-static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
-                                  0x1f, 0x3f, 0x7f, 0xff};
-
 /*-
- * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<12.
- * Here |bits| is |d|.  For efficiency, we handle the d=1, and d=12 cases
- * separately.
+ * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
+ * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
  */
 static void scalar_encode(uint8_t *out, const scalar *s, int bits)
-{
-    uint8_t out_byte = 0;
-    int out_byte_bits = 0;
-    int i, element_bits_done, chunk_bits, out_bits_remaining;
-    uint16_t element;
-
-    for (i = 0; i < DEGREE; i++) {
-        element = s->c[i];
-        element_bits_done = 0;
-        while (element_bits_done < bits) {
-            chunk_bits = bits - element_bits_done;
-            out_bits_remaining = 8 - out_byte_bits;
-            if (chunk_bits >= out_bits_remaining) {
-                chunk_bits = out_bits_remaining;
-                out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
-                *out++ = out_byte;
-                out_byte_bits = 0;
-                out_byte = 0;
-            } else {
-                out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
-                out_byte_bits += chunk_bits;
-            }
-            element_bits_done += chunk_bits;
-            element >>= chunk_bits;
-        }
-    }
-    if (out_byte_bits > 0)
-        *out = out_byte;
-}
-
-/*
- * scalar_encode_12 is |scalar_encode| specialised for |bits| == 12.
- */
-static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s)
 {
     const uint16_t *curr = s->c, *end = curr + DEGREE;
-    uint16_t c1, c2;
-
-#define encode_two                                                      \
-        c1 = *curr++;                                                   \
-        c2 = *curr++;                                                   \
-        *out++ = (uint8_t) c1;                                          \
-        *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4));   \
-        *out++ = (uint8_t) (c2 >> 4)
-
-    while (curr < end) {
-        encode_two;
-        encode_two;
-        encode_two;
-        encode_two;
-    }
-#undef encode_two
+    uint64_t accum = 0, element;
+    int used = 0;
+
+    do {
+        element = *curr++;
+        if (used + bits < 64) {
+            accum |= element << used;
+            used += bits;
+        } else if (used + bits > 64) {
+            out = OPENSSL_store_u64_le(out, accum | (element << used));
+            accum = element >> (64 - used);
+            used = (used + bits) - 64;
+        } else {
+            out = OPENSSL_store_u64_le(out, accum | (element << used));
+            accum = 0;
+            used = 0;
+        }
+    } while (curr < end);
 }
 
 /*
@@ -709,42 +673,51 @@ static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
  * separately.
  *
  * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
- * |out|. It returns one on success and zero if any parsed value is >=
- * |kPrime|.
- *
- * Note: Used in decrypt_cpa(), which returns void and so does not check the
- * return value of this function.  But also used in vector_decode(), which
- * returns early when scalar_decode() fails.
+ * |out|.
  */
-static int scalar_decode(scalar *out, const uint8_t *in, int bits)
+static void scalar_decode(scalar *out, const uint8_t *in, int bits)
 {
-    uint8_t in_byte = 0;
-    int in_byte_bits_left = 0;
-    int i, element_bits_done, chunk_bits;
-    uint16_t element;
-
-    for (i = 0; i < DEGREE; i++) {
-        element = 0;
-        element_bits_done = 0;
-        while (element_bits_done < bits) {
-            if (in_byte_bits_left == 0) {
-                in_byte = *in;
-                in++;
-                in_byte_bits_left = 8;
-            }
-            chunk_bits = bits - element_bits_done;
-            if (chunk_bits > in_byte_bits_left)
-                chunk_bits = in_byte_bits_left;
-            element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
-            in_byte_bits_left -= chunk_bits;
-            in_byte >>= chunk_bits;
-            element_bits_done += chunk_bits;
+    uint16_t *curr = out->c, *end = curr + DEGREE;
+    uint64_t accum = 0;
+    int accum_bits = 0, todo = bits;
+    uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
+    uint16_t element = 0;
+
+    do {
+        if (accum_bits == 0) {
+            in = OPENSSL_load_u64_le(&accum, in);
+            accum_bits = 64;
         }
-        if (constant_time_declassify_int(element >= kPrime))
-            return 0;
-        out->c[i] = element;
-    }
-    return 1;
+        if (todo == bits && accum_bits >= bits) {
+            /* No partial "element", and all the required bits available */
+            *curr++ = ((uint16_t) accum) & mask;
+            accum >>= bits;
+            accum_bits -= bits;
+        } else if (accum_bits >= todo) {
+            /* A partial "element", and all the required bits available */
+            *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
+            accum >>= todo;
+            accum_bits -= todo;
+            element = 0;
+            todo = bits;
+            mask = bitmask;
+        } else {
+            /*
+             * Only some of the requisite bits accumulated, store |accum_bits|
+             * of these in |element|.  The accumulated bitcount becomes 0, but
+             * as soon as we have more bits we'll want to merge accum_bits
+             * fewer of them into the final |element|.
+             *
+             * Note that with a 64-bit accumulator and |bits| always 12 or
+             * less, if we're here, the previous iteration had all the
+             * requisite bits, and so there are no kept bits in |element|.
+             */
+            element = ((uint16_t) accum) & mask;
+            todo -= accum_bits;
+            mask = bitmask >> accum_bits;
+            accum_bits = 0;
+        }
+    } while (curr < end);
 }
 
 static __owur
@@ -798,7 +771,7 @@ scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
         b >>= 1
 
     /* Unrolled to process each byte in one iteration */
-    while (curr < end) {
+    do {
         b = *in++;
         decode_decompress_add_bit;
         decode_decompress_add_bit;
@@ -809,7 +782,7 @@ scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
         decode_decompress_add_bit;
         decode_decompress_add_bit;
         decode_decompress_add_bit;
-    }
+    } while (curr < end);
 #undef decode_decompress_add_bit
 }
 
@@ -893,8 +866,9 @@ static void scalar_decompress(scalar *s, int bits)
 /* Addition updating the LHS vector in-place. */
 static void vector_add(scalar *lhs, const scalar *rhs, int rank)
 {
-    while (rank-- > 0)
+    do {
         scalar_add(lhs++, rhs++);
+    } while (--rank > 0);
 }
 
 /*
@@ -925,23 +899,12 @@ vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
     int stride = bits * DEGREE / 8;
 
     for (; rank-- > 0; in += stride, ++out) {
-        if (!scalar_decode(out, in, bits))
-            return;
+        scalar_decode(out, in, bits);
         scalar_decompress(out, bits);
         scalar_ntt(out);
     }
 }
 
-/* vector_encode(), specialised to bits == 12. */
-static void vector_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *a,
-                             int rank)
-{
-    int stride = 3 * DEGREE / 2;
-
-    for (; rank-- > 0; out += stride)
-        scalar_encode_12(out, a++);
-}
-
 /* vector_decode(), specialised to bits == 12. */
 static __owur
 int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
@@ -957,8 +920,9 @@ int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
 /* In-place compression of each scalar component */
 static void vector_compress(scalar *a, int bits, int rank)
 {
-    while (rank-- > 0)
+    do {
         scalar_compress(a++, bits);
+    } while (--rank > 0);
 }
 
 /* The output scalar must not overlap with the inputs */
@@ -1051,35 +1015,27 @@ int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
     if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
         return 0;
 
-    /*
-     * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero| for a
-     * discussion on why the value barrier is by default omitted.  While this
-     * could have been written reduce_once(value + kPrime), this is one extra
-     * addition and small range of |value| tempts some versions of Clang to
-     * emit a branch.
-     */
-#define cbd_2_do_byte                               \
-        b = *r++;                                   \
-                                                    \
-        value = bit0(b) + bitn(1, b);               \
-        value -= bitn(2, b) + bitn(3, b);           \
-        mask = constish_time_non_zero(value >> 15); \
-        *curr++ = value + (kPrime & mask);          \
-                                                    \
-        value = bitn(4, b) + bitn(5, b);            \
-        value -= bitn(6, b) + bitn(7, b);           \
-        mask = constish_time_non_zero(value >> 15); \
-        *curr++ = value + (kPrime & mask)
-
-    /* Unrolled, 4 random bytes in, 8 coefficients out */
-    while (curr < end) {
-        cbd_2_do_byte;
-        cbd_2_do_byte;
-        cbd_2_do_byte;
-        cbd_2_do_byte;
-    }
+    do {
+        b = *r++;
+
+        /*
+         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
+         * for a discussion on why the value barrier is by default omitted.
+         * While this could have been written reduce_once(value + kPrime), this
+         * is one extra addition and small range of |value| tempts some
+         * versions of Clang to emit a branch.
+         */
+        value = bit0(b) + bitn(1, b);
+        value -= bitn(2, b) + bitn(3, b);
+        mask = constish_time_non_zero(value >> 15);
+        *curr++ = value + (kPrime & mask);
+
+        value = bitn(4, b) + bitn(5, b);
+        value -= bitn(6, b) + bitn(7, b);
+        mask = constish_time_non_zero(value >> 15);
+        *curr++ = value + (kPrime & mask);
+    } while (curr < end);
     return 1;
-#undef cbd_2_do_byte
 }
 
 /*
@@ -1100,45 +1056,39 @@ int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
     if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
         return 0;
 
-    /*
-     * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero| for a
-     * discussion on why the value barrier is by default omitted.  While this
-     * could have been written reduce_once(value + kPrime), this is one extra
-     * addition and small range of |value| tempts some versions of Clang to
-     * emit a branch.
-     */
-#define cbd_3_do_bytes                                      \
-        b1 = *r++;                                          \
-        b2 = *r++;                                          \
-        b3 = *r++;                                          \
-                                                            \
-        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);       \
-        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);  \
-        mask = constish_time_non_zero(value >> 15);         \
-        *curr++ = value + (kPrime & mask);                  \
-                                                            \
-        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);       \
-        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);   \
-        mask = constish_time_non_zero(value >> 15);         \
-        *curr++ = value + (kPrime & mask);                  \
-                                                            \
-        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);    \
-        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);      \
-        mask = constish_time_non_zero(value >> 15);         \
-        *curr++ = value + (kPrime & mask);                  \
-                                                            \
-        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);    \
-        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);   \
-        mask = constish_time_non_zero(value >> 15);         \
+    do {
+        b1 = *r++;
+        b2 = *r++;
+        b3 = *r++;
+
+        /*
+         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
+         * for a discussion on why the value barrier is by default omitted.
+         * While this could have been written reduce_once(value + kPrime), this
+         * is one extra addition and small range of |value| tempts some
+         * versions of Clang to emit a branch.
+         */
+        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
+        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
+        mask = constish_time_non_zero(value >> 15);
         *curr++ = value + (kPrime & mask);
 
-    /* Unrolled, 6 random bytes in, 8 coefficients out */
-    while (curr < end) {
-        cbd_3_do_bytes;
-        cbd_3_do_bytes;
-    }
+        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
+        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
+        mask = constish_time_non_zero(value >> 15);
+        *curr++ = value + (kPrime & mask);
+
+        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
+        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
+        mask = constish_time_non_zero(value >> 15);
+        *curr++ = value + (kPrime & mask);
+
+        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
+        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
+        mask = constish_time_non_zero(value >> 15);
+        *curr++ = value + (kPrime & mask);
+    } while (curr < end);
     return 1;
-#undef cbd_3_do_bytes
 }
 
 /*
@@ -1153,11 +1103,11 @@ int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
     uint8_t input[ML_KEM_RANDOM_BYTES + 1];
 
     memcpy(input, seed, ML_KEM_RANDOM_BYTES);
-    while (rank-- > 0) {
+    do {
         input[ML_KEM_RANDOM_BYTES] = (*counter)++;
         if (!cbd(out++, input, mdctx, key))
             return 0;
-    }
+    } while (--rank > 0);
     return 1;
 }
 
@@ -1172,17 +1122,17 @@ int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
     uint8_t input[ML_KEM_RANDOM_BYTES + 1];
 
     memcpy(input, seed, ML_KEM_RANDOM_BYTES);
-    while (rank-- > 0) {
+    do {
         input[ML_KEM_RANDOM_BYTES] = (*counter)++;
         if (!cbd(out, input, mdctx, key))
             return 0;
         scalar_ntt(out++);
-    }
+    } while (--rank > 0);
     return 1;
 }
 
 /* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
-#define CBD1(evp_type)  ((evp_type) == NID_ML_KEM_512 ? cbd_3 : cbd_2)
+#define CBD1(evp_type)  ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
 
 /*
  * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
@@ -1285,7 +1235,7 @@ static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
     const uint8_t *rho = key->rho;
     const ML_KEM_VINFO *vinfo = key->vinfo;
 
-    vector_encode_12(out, key->t, vinfo->rank);
+    vector_encode(out, key->t, 12, vinfo->rank);
     memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
 }
 
@@ -1299,7 +1249,7 @@ static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
 {
     const ML_KEM_VINFO *vinfo = key->vinfo;
 
-    vector_encode_12(out, key->s, vinfo->rank);
+    vector_encode(out, key->s, 12, vinfo->rank);
     out += vinfo->vector_bytes;
     encode_pubkey(out, key);
     out += vinfo->pubkey_bytes;
@@ -1570,11 +1520,11 @@ free_storage(ML_KEM_KEY *key)
 const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
 {
     switch (evp_type) {
-    case NID_ML_KEM_512:
+    case EVP_PKEY_ML_KEM_512:
         return &vinfo_map[0];
-    case NID_ML_KEM_768:
+    case EVP_PKEY_ML_KEM_768:
         return &vinfo_map[1];
-    case NID_ML_KEM_1024:
+    case EVP_PKEY_ML_KEM_1024:
         return &vinfo_map[2];
     }
     return NULL;
@@ -1852,7 +1802,7 @@ int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
      * We stack-allocate these.
      */
 #   define case_encap_seed(bits)                                            \
-    case NID_ML_KEM_##bits:                                                 \
+    case EVP_PKEY_ML_KEM_##bits:                                            \
         {                                                                   \
             scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
                                                                             \
@@ -1931,7 +1881,7 @@ int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
      * We stack-allocate these.
      */
 #   define case_decap(bits)                                             \
-    case NID_ML_KEM_##bits:                                             \
+    case EVP_PKEY_ML_KEM_##bits:                                        \
         {                                                               \
             uint8_t cbuf[CTEXT_BYTES(bits)];                            \
             scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
index 837f14626b5e20da978678c2ba2afc251e368e80..8f5b2fa86b98708d61751fd6f423fb09092b9da5 100644 (file)
@@ -893,9 +893,15 @@ static int ml_kem_spki_pub_to_der(const void *vkey, unsigned char **pder,
     publen = key->vinfo->pubkey_bytes;
 
     if (pder != NULL
-        && ((*pder = OPENSSL_malloc(publen)) == NULL
-            || !ossl_ml_kem_encode_public_key(*pder, publen, key)))
+        && (*pder = OPENSSL_malloc(publen)) == NULL)
         return 0;
+    if (!ossl_ml_kem_encode_public_key(*pder, publen, key)) {
+        ERR_raise_data(ERR_LIB_OSSL_ENCODER, ERR_R_INTERNAL_ERROR,
+                       "error encoding %s public key",
+                       key->vinfo->algorithm_name);
+        OPENSSL_free(*pder);
+        return 0;
+    }
 
     return publen;
 }
index e27823163682c89fc9645682ab5689a48a293525..dfb39e3813acdda4be374a7cc194e7738a28dd40 100644 (file)
@@ -640,12 +640,13 @@ static void *ml_kem_dup(const void *vkey, int selection)
     static void *ml_kem_##bits##_new(void *provctx) \
     { \
         return ml_kem_new(provctx == NULL ? NULL : PROV_LIBCTX_OF(provctx), \
-                          NULL, NID_ML_KEM_##bits); \
+                          NULL, EVP_PKEY_ML_KEM_##bits); \
     } \
     static void *ml_kem_##bits##_gen_init(void *provctx, int selection, \
                                           const OSSL_PARAM params[]) \
     { \
-        return ml_kem_gen_init(provctx, selection, params, NID_ML_KEM_##bits); \
+        return ml_kem_gen_init(provctx, selection, params, \
+                               EVP_PKEY_ML_KEM_##bits); \
     } \
     const OSSL_DISPATCH ossl_ml_kem_##bits##_keymgmt_functions[] = { \
         { OSSL_FUNC_KEYMGMT_NEW, (OSSL_FUNC) ml_kem_##bits##_new }, \
index efd766bb79bc52d6040716ec4bfb48e6ecd7293b..5b9fb6124794dff1a00ab139ed2eef6ad4989a26 100644 (file)
@@ -44,11 +44,11 @@ static const int minimal_selection = OSSL_KEYMGMT_SELECT_DOMAIN_PARAMETERS
 
 /* Must match DECLARE_DISPATCH invocations at the end of the file */
 static const ECDH_VINFO hybrid_vtable[] = {
-    { "EC",  "P-256", 65, 32, 32, 1, NID_ML_KEM_768 },
-    { "EC",  "P-384", 97, 48, 48, 1, NID_ML_KEM_1024 },
+    { "EC",  "P-256", 65, 32, 32, 1, EVP_PKEY_ML_KEM_768 },
+    { "EC",  "P-384", 97, 48, 48, 1, EVP_PKEY_ML_KEM_1024 },
 #if !defined(OPENSSL_NO_ECX)
-    { "X25519", NULL, 32, 32, 32, 0, NID_ML_KEM_768 },
-    { "X448",   NULL, 56, 56, 56, 0, NID_ML_KEM_1024 },
+    { "X25519", NULL, 32, 32, 32, 0, EVP_PKEY_ML_KEM_768 },
+    { "X448",   NULL, 56, 56, 56, 0, EVP_PKEY_ML_KEM_1024 },
 #endif
 };
 
index f55e1cf48e1e312f2e8df573581b5b9c89e84f1a..f46da9d692acb03840fd382bbfaeeb2e58512596 100644 (file)
@@ -217,7 +217,11 @@ static int test_ml_kem(void)
 
 static int test_non_derandomised_ml_kem(void)
 {
-    static const int alg[3] = { NID_ML_KEM_512, NID_ML_KEM_768, NID_ML_KEM_1024 };
+    static const int alg[3] = {
+        EVP_PKEY_ML_KEM_512,
+        EVP_PKEY_ML_KEM_768,
+        EVP_PKEY_ML_KEM_1024
+    };
     EVP_RAND_CTX *privctx;
     EVP_RAND_CTX *pubctx;
     EVP_MD *sha256;
index 0f847a12d40fdd538740dde1fd4a722ba74d07c6..78b2168f81a7026b812be1e981f4155952d2ff8c 100644 (file)
@@ -92,7 +92,11 @@ static uint8_t ml_kem_expected_shared_secret[3][32] = {
 
 static int sanity_test(void)
 {
-    static const int alg[3] = { NID_ML_KEM_512, NID_ML_KEM_768, NID_ML_KEM_1024 };
+    static const int alg[3] = {
+        EVP_PKEY_ML_KEM_512,
+        EVP_PKEY_ML_KEM_768,
+        EVP_PKEY_ML_KEM_1024
+    };
     EVP_RAND_CTX *privctx;
     EVP_RAND_CTX *pubctx;
     EVP_MD *sha256 = EVP_MD_fetch(NULL, "sha256", NULL);