]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
ML-KEM libcrypto implementation polish
authorViktor Dukhovni <openssl-users@dukhovni.org>
Sat, 21 Dec 2024 16:07:33 +0000 (03:07 +1100)
committerTomas Mraz <tomas@openssl.org>
Fri, 14 Feb 2025 09:50:57 +0000 (10:50 +0100)
* Core ML_KEM constants in new <openssl/ml_kem.h>

* Renamed variant ordinals to ML_KEM_<bits>_VARIANT, freeing
  up the unadorned ML_KEM_<bits> names.

* Fewer/cleaner macros in <crypto/ml_kem.h>

* Fewer/cleaner macros for setting up the ML_KEM_VINFO table.

* Made (d, z) be separate inputs to the now single key generation
  function.  Both or neither have to be NULL.  This supports potential
  future callers that store them in a different order, or in separate
  buffers.

    - Random values are chosen when both are NULL, we never return the
      generated seeds, rather we may, when/if (d, z) private key support
      is added, store these in the expanded key, and make them available
      for import/export.

* No need for a stand-by keygen encoded public key buffer when the
  caller does not provide one (will ask for it later if needed).
  New `hash_h_pubkey` function can compute the public hash from
  the expanded form in constant space (384 bytes for 12-bit encoded
  scalar).

* Simplified code in `scalar_mult`.

* New `scalar_mult_add` adds the product to an existing scalar.
  Used in new `matrix_mult_transpose_add` replacing `matrix_mult_transpose`.

* Unrolled loop in `encode_12`.

* Folded decompression and inverse NTT into vecode_decode, the three
  were always used together.

* Folded inverse NTT into former `matrix_mult` as `matrix_mult_intt`,
  always used together.

* New gencbd_vector_ntt combines CBD vector generation with inverse NTT
  in one pass.

* All this makes for more readable code in `decrypt_cpa` and especially
  `genkey()`, which no longer requires caller-allocated variant-specific
  temporary storage (just a single EVP_MD_CTX is still needed).

Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Neil Horman <nhorman@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/26236)

crypto/ml_kem/ml_kem.c
include/crypto/ml_kem.h
include/crypto/types.h
include/openssl/ml_kem.h [new file with mode: 0644]
providers/common/capabilities.c
providers/implementations/keymgmt/ml_kem_kmgmt.c
test/ml_kem_evp_extra_test.c
test/ml_kem_internal_test.c

index ca75cbe7b84bad2b86235757f6629624bcc89265..df30c26d447f2edc5a15deb951fb9edfa1163cd6 100644 (file)
 #define bit0(b)     ((b) & 1)
 #define bitn(n, b)  (((b) >> n) & 1)
 
-#define DEGREE              ML_KEM_DEGREE
-#define BARRETT_SHIFT       (2 * ML_KEM_LOG2PRIME)
+/*
+ * 12 bits are sufficient to losslessly represent values in [0, q-1].
+ * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
+ */
+#define DEGREE          ML_KEM_DEGREE
+#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
+#define LOG2PRIME       12
+#define BARRETT_SHIFT   (2 * LOG2PRIME)
+
 #ifdef SHA3_BLOCKSIZE
-# define SHAKE128_BLOCKSIZE  SHA3_BLOCKSIZE(128)
+# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
 #endif
 
 /*
@@ -82,12 +89,12 @@ typedef struct ossl_ml_kem_scalar_st {
     uint16_t c[ML_KEM_DEGREE];
 } scalar;
 
-/* General form of public and private key storage */
+/* Key material allocation layout */
 #define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
-    struct ossl_ml_kem_##name##_st { \
+    struct name##_alloc { \
         /* Public vector |t| */ \
         scalar tbuf[(rank)]; \
-        /* Pre-computed matrix |m| */ \
+        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
         scalar mbuf[(rank)*(rank)] \
         /* optional private key data */ \
         private_sz \
@@ -95,10 +102,10 @@ typedef struct ossl_ml_kem_scalar_st {
 
 /* Declare variant-specific public and private storage */
 #define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
-    DECLARE_ML_KEM_KEYDATA(bits##_puballoc, ML_KEM_##bits##_RANK,;); \
-    DECLARE_ML_KEM_KEYDATA(bits##_prvalloc, ML_KEM_##bits##_RANK,;\
+    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
+    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
         scalar sbuf[ML_KEM_##bits##_RANK]; \
-        uint8_t zbuf[ML_KEM_RANDOM_BYTES];)
+        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
 DECLARE_ML_KEM_VARIANT_KEYDATA(512);
 DECLARE_ML_KEM_VARIANT_KEYDATA(768);
 DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
@@ -108,31 +115,20 @@ DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
 typedef __owur
 int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
                 EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
-
-/* The wire form of a losslessly encoded vector (12-bits per element). */
-#define ML_KEM_VECTOR_BYTES(rank) \
-    ((3 * ML_KEM_DEGREE / 2) * (rank))
-
-/*
- * The expanded internal form stores each coefficient as a 16-bit unsigned int
- */
-#define ML_KEM_VECALLOC_BYTES(rank) \
-    (2 * ML_KEM_DEGREE * (rank))
-
-/*
- * The wire-form public key consists of the lossless encoding of the vector
- * "t" = "A" * "s" + "e", followed by public seed "rho".
- */
-#define ML_KEM_PUBKEY_BYTES(rank) \
-    (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_RANDOM_BYTES)
+static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s);
 
 /*
- * Our internal serialised private key concatenates serialisations of "s", the
- * public key, the public key hash, and the failure secret "z".
+ * The wire-form of a losslessly encoded vector uses 12-bits per element.
+ *
+ * The wire-form public key consists of the lossless encoding of the public
+ * vector |t|, followed by the public seed |rho|.
+ *
+ * Our serialised private key concatenates serialisations of the private vector
+ * |s|, the public key, the public key hash, and the failure secret |z|.
  */
-#define ML_KEM_PRVKEY_BYTES(rank) \
-    (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_PUBKEY_BYTES(rank) \
-     + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)
+#define VECTOR_BYTES(b)     ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
+#define PUBKEY_BYTES(b)     (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
+#define PRVKEY_BYTES(b)     (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
 
 /*
  * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
@@ -140,57 +136,9 @@ int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
  * "dv" bits, respectively.  This encoding is the ciphertext input for
  * decapsulation.
  */
-#define ML_KEM_U_VECTOR_BYTES(rank, du) \
-    ((ML_KEM_DEGREE / 8) * (du) * (rank))
-#define ML_KEM_V_SCALAR_BYTES(dv) \
-    ((ML_KEM_DEGREE / 8) * (dv))
-#define ML_KEM_CTEXT_BYTES(rank, du, dv) \
-    (ML_KEM_U_VECTOR_BYTES(rank, du) + ML_KEM_V_SCALAR_BYTES(dv))
-
-/*
- * Variant-specific sizes
- */
-#define ML_KEM_512_VECTOR_BYTES \
-    ML_KEM_VECTOR_BYTES(ML_KEM_512_RANK)
-#define ML_KEM_768_VECTOR_BYTES \
-    ML_KEM_VECTOR_BYTES(ML_KEM_768_RANK)
-#define ML_KEM_1024_VECTOR_BYTES \
-    ML_KEM_VECTOR_BYTES(ML_KEM_1024_RANK)
-
-#define ML_KEM_512_PUBLIC_KEY_BYTES \
-    ML_KEM_PUBKEY_BYTES(ML_KEM_512_RANK)
-#define ML_KEM_768_PUBLIC_KEY_BYTES \
-    ML_KEM_PUBKEY_BYTES(ML_KEM_768_RANK)
-#define ML_KEM_1024_PUBLIC_KEY_BYTES \
-    ML_KEM_PUBKEY_BYTES(ML_KEM_1024_RANK)
-
-#define ML_KEM_512_PRIVATE_KEY_BYTES \
-    ML_KEM_PRVKEY_BYTES(ML_KEM_512_RANK)
-#define ML_KEM_768_PRIVATE_KEY_BYTES \
-    ML_KEM_PRVKEY_BYTES(ML_KEM_768_RANK)
-#define ML_KEM_1024_PRIVATE_KEY_BYTES \
-    ML_KEM_PRVKEY_BYTES(ML_KEM_1024_RANK)
-
-#define ML_KEM_512_U_VECTOR_BYTES \
-    ML_KEM_U_VECTOR_BYTES(ML_KEM_512_RANK, ML_KEM_512_DU)
-#define ML_KEM_768_U_VECTOR_BYTES \
-    ML_KEM_U_VECTOR_BYTES(ML_KEM_768_RANK, ML_KEM_768_DU)
-#define ML_KEM_1024_U_VECTOR_BYTES \
-    ML_KEM_U_VECTOR_BYTES(ML_KEM_1024_RANK, ML_KEM_1024_DU)
-
-#define ML_KEM_512_V_SCALAR_BYTES \
-    ML_KEM_V_SCALAR_BYTES(ML_KEM_512_DV)
-#define ML_KEM_768_V_SCALAR_BYTES \
-    ML_KEM_V_SCALAR_BYTES(ML_KEM_768_DV)
-#define ML_KEM_1024_V_SCALAR_BYTES \
-    ML_KEM_V_SCALAR_BYTES(ML_KEM_1024_DV)
-
-#define ML_KEM_512_CIPHERTEXT_BYTES \
-    (ML_KEM_512_U_VECTOR_BYTES + ML_KEM_512_V_SCALAR_BYTES)
-#define ML_KEM_768_CIPHERTEXT_BYTES \
-    (ML_KEM_768_U_VECTOR_BYTES + ML_KEM_768_V_SCALAR_BYTES)
-#define ML_KEM_1024_CIPHERTEXT_BYTES \
-    (ML_KEM_1024_U_VECTOR_BYTES + ML_KEM_1024_V_SCALAR_BYTES)
+#define U_VECTOR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
+#define V_SCALAR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DV)
+#define CTEXT_BYTES(b)      (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
 
 /*
  * Per-variant fixed parameters
@@ -198,51 +146,51 @@ int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
 static const ML_KEM_VINFO vinfo_map[3] = {
     {
         "ML-KEM-512",
-        ML_KEM_512_VECTOR_BYTES,
-        ML_KEM_512_PRIVATE_KEY_BYTES,
-        ML_KEM_512_PUBLIC_KEY_BYTES,
-        ML_KEM_512_CIPHERTEXT_BYTES,
-        ML_KEM_512_U_VECTOR_BYTES,
-        sizeof(struct ossl_ml_kem_512_puballoc_st),
-        sizeof(struct ossl_ml_kem_512_prvalloc_st),
-        ML_KEM_512,
-        512,
+        PRVKEY_BYTES(512),
+        sizeof(struct prvkey_512_alloc),
+        PUBKEY_BYTES(512),
+        sizeof(struct pubkey_512_alloc),
+        CTEXT_BYTES(512),
+        VECTOR_BYTES(512),
+        U_VECTOR_BYTES(512),
+        ML_KEM_512_VARIANT,
+        ML_KEM_512_BITS,
         ML_KEM_512_RANK,
         ML_KEM_512_DU,
         ML_KEM_512_DV,
-        ML_KEM_512_RNGSEC
+        ML_KEM_512_SECBITS
     },
     {
         "ML-KEM-768",
-        ML_KEM_768_VECTOR_BYTES,
-        ML_KEM_768_PRIVATE_KEY_BYTES,
-        ML_KEM_768_PUBLIC_KEY_BYTES,
-        ML_KEM_768_CIPHERTEXT_BYTES,
-        ML_KEM_768_U_VECTOR_BYTES,
-        sizeof(struct ossl_ml_kem_768_puballoc_st),
-        sizeof(struct ossl_ml_kem_768_prvalloc_st),
-        ML_KEM_768,
-        768,
+        PRVKEY_BYTES(768),
+        sizeof(struct prvkey_768_alloc),
+        PUBKEY_BYTES(768),
+        sizeof(struct pubkey_768_alloc),
+        CTEXT_BYTES(768),
+        VECTOR_BYTES(768),
+        U_VECTOR_BYTES(768),
+        ML_KEM_768_VARIANT,
+        ML_KEM_768_BITS,
         ML_KEM_768_RANK,
         ML_KEM_768_DU,
         ML_KEM_768_DV,
-        ML_KEM_768_RNGSEC
+        ML_KEM_768_SECBITS
     },
     {
         "ML-KEM-1024",
-        ML_KEM_1024_VECTOR_BYTES,
-        ML_KEM_1024_PRIVATE_KEY_BYTES,
-        ML_KEM_1024_PUBLIC_KEY_BYTES,
-        ML_KEM_1024_CIPHERTEXT_BYTES,
-        ML_KEM_1024_U_VECTOR_BYTES,
-        sizeof(struct ossl_ml_kem_1024_puballoc_st),
-        sizeof(struct ossl_ml_kem_1024_prvalloc_st),
-        ML_KEM_1024,
-        1024,
+        PRVKEY_BYTES(1024),
+        sizeof(struct prvkey_1024_alloc),
+        PUBKEY_BYTES(1024),
+        sizeof(struct pubkey_1024_alloc),
+        CTEXT_BYTES(1024),
+        VECTOR_BYTES(1024),
+        U_VECTOR_BYTES(1024),
+        ML_KEM_1024_VARIANT,
+        ML_KEM_1024_BITS,
         ML_KEM_1024_RANK,
         ML_KEM_1024_DU,
         ML_KEM_1024_DV,
-        ML_KEM_1024_RNGSEC
+        ML_KEM_1024_SECBITS
     }
 };
 
@@ -255,7 +203,7 @@ static const int kPrime = ML_KEM_PRIME;
 static const unsigned int kBarrettShift = BARRETT_SHIFT;
 static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
 static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
-static const uint16_t kInverseDegree = ML_KEM_INVERSE_DEGREE;
+static const uint16_t kInverseDegree = INVERSE_DEGREE;
 
 /*
  * Python helper:
@@ -366,6 +314,32 @@ int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
         && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
 }
 
+/* Incremental hash_h of expanded public key */
+static int
+hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
+              EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
+{
+    const ML_KEM_VINFO *vinfo = key->vinfo;
+    const scalar *t = key->t, *end = t + vinfo->rank;
+    unsigned int sz;
+
+    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
+        return 0;
+
+    while (t < end) {
+        uint8_t buf[3 * DEGREE / 2];
+
+        scalar_encode_12(buf, t++);
+        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
+            return 0;
+    }
+
+    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
+        return 0;
+    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
+        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
+}
+
 /*
  * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
  * length input, producing 64 bytes of output, in particular the seeds
@@ -579,29 +553,39 @@ static void scalar_sub(scalar *lhs, const scalar *rhs)
 static void scalar_mult(scalar *out, const scalar *lhs,
                         const scalar *rhs)
 {
-    int i;
-    uint32_t real_real, img_img, real_img, img_real;
-
-    for (i = 0; i < DEGREE / 2; i++) {
-        real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
-        img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
-        real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
-        img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
-        out->c[2 * i] =
-            reduce(real_real +
-                   (uint32_t)reduce(img_img) * kModRoots[i]);
-        out->c[2 * i + 1] = reduce(img_real + real_img);
+    uint16_t *curr = out->c, *end = curr + DEGREE;
+    const uint16_t *lc = lhs->c, *rc = rhs->c;
+    const uint16_t *roots = kModRoots;
+
+    while (curr < end) {
+        uint32_t l0 = *lc++, r0 = *rc++;
+        uint32_t l1 = *lc++, r1 = *rc++;
+        uint32_t zetapow = *roots++;
+
+        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
+        *curr++ = reduce(l0 * r1 + l1 * r0);
     }
 }
 
+/* Above, but add the result to an existing scalar */
 static ossl_inline
 void scalar_mult_add(scalar *out, const scalar *lhs,
                      const scalar *rhs)
 {
-    scalar product;
+    uint16_t *curr = out->c, *end = curr + DEGREE;
+    const uint16_t *lc = lhs->c, *rc = rhs->c;
+    const uint16_t *roots = kModRoots;
 
-    scalar_mult(&product, lhs, rhs);
-    scalar_add(out, &product);
+    while (curr < end) {
+        uint32_t l0 = *lc++, r0 = *rc++;
+        uint32_t l1 = *lc++, r1 = *rc++;
+        uint16_t *c0 = curr++;
+        uint16_t *c1 = curr++;
+        uint32_t zetapow = *roots++;
+
+        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
+        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
+    }
 }
 
 static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
@@ -648,17 +632,23 @@ static void scalar_encode(uint8_t *out, const scalar *s, int bits)
  */
 static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s)
 {
-    const uint16_t *c = s->c;
-    int i;
+    const uint16_t *curr = s->c, *end = curr + DEGREE;
+    uint16_t c1, c2;
 
-    for (i = 0; i < DEGREE / 2; ++i) {
-        uint16_t c1 = *c++;
-        uint16_t c2 = *c++;
+#define encode_two                                                      \
+        c1 = *curr++;                                                   \
+        c2 = *curr++;                                                   \
+        *out++ = (uint8_t) c1;                                          \
+        *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4));   \
+        *out++ = (uint8_t) (c2 >> 4)
 
-        *out++ = (uint8_t) c1;
-        *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4));
-        *out++ = (uint8_t) (c2 >> 4);
+    while (curr < end) {
+        encode_two;
+        encode_two;
+        encode_two;
+        encode_two;
     }
+#undef encode_two
 }
 
 /*
@@ -885,20 +875,25 @@ static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
 }
 
 /*
- * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns one on
- * success or zero if any parsed value is >= |ML_KEM_PRIME|.
+ * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
+ * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
+ * then decompressed and transformed via the NTT.
  *
  * Note: Used only in decrypt_cpa(), which returns void and so does not check
  * the return value of this function.  Side-channels are fine when the input
  * ciphertext to decap() is simply syntactically invalid.
  */
-static void vector_decode(scalar *out, const uint8_t *in, int bits, int rank)
+static void
+vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
 {
     int stride = bits * DEGREE / 8;
 
-    for (; rank-- > 0; in += stride)
-        if (!scalar_decode(out++, in, bits))
+    for (; rank-- > 0; in += stride, ++out) {
+        if (!scalar_decode(out, in, bits))
             return;
+        scalar_decompress(out, bits);
+        scalar_ntt(out);
+    }
 }
 
 /* vector_encode(), specialised to bits == 12. */
@@ -930,13 +925,6 @@ static void vector_compress(scalar *a, int bits, int rank)
         scalar_compress(a++, bits);
 }
 
-/* In-place decompression of each scalar component */
-static void vector_decompress(scalar *a, int bits, int rank)
-{
-    while (rank-- > 0)
-        scalar_decompress(a++, bits);
-}
-
 /* The output scalar must not overlap with the inputs */
 static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
                           int rank)
@@ -946,23 +934,12 @@ static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
         scalar_mult_add(out, ++lhs, ++rhs);
 }
 
-/* In-place NTT transform of a vector */
-static void vector_ntt(scalar *a, int rank)
-{
-    while (rank-- > 0)
-        scalar_ntt(a++);
-}
-
-/* In-place inverse NTT transform of a vector */
-static void vector_inverse_ntt(scalar *a, int rank)
-{
-    while (rank-- > 0)
-        scalar_inverse_ntt(a++);
-}
-
-/* Here, the output vector must not overlap with the inputs */
+/*
+ * Here, the output vector must not overlap with the inputs, the result is
+ * directly subjected to inverse NTT.
+ */
 static void
-matrix_mult(scalar *out, const scalar *m, const scalar *a, int rank)
+matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
 {
     const scalar *ar;
     int i, j;
@@ -971,18 +948,19 @@ matrix_mult(scalar *out, const scalar *m, const scalar *a, int rank)
         scalar_mult(out, m++, ar = a);
         for (j = rank - 1; j > 0; --j)
             scalar_mult_add(out, m++, ++ar);
+        scalar_inverse_ntt(out);
     }
 }
 
 /* Here, the output vector must not overlap with the inputs */
 static void
-matrix_mult_transpose(scalar *out, const scalar *m, const scalar *a, int rank)
+matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
 {
     const scalar *mc = m, *mr, *ar;
     int i, j;
 
-    for (i = rank; i-- > 0; ++out)  {
-        scalar_mult(out, mr = mc++, ar = a);
+    for (i = rank; i-- > 0; ++out) {
+        scalar_mult_add(out, mr = mc++, ar = a);
         for (j = rank; --j > 0; )
             scalar_mult_add(out, (mr += rank), ++ar);
     }
@@ -1147,8 +1125,28 @@ int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
     return 1;
 }
 
+/*
+ * As above plus NTT transform.
+ */
+static __owur
+int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
+                      const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
+                      EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
+{
+    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
+
+    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
+    while (rank-- > 0) {
+        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
+        if (!cbd(out, input, mdctx, key))
+            return 0;
+        scalar_ntt(out++);
+    }
+    return 1;
+}
+
 /* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
-static CBD_FUNC const cbd1[ML_KEM_1024 + 1] = { cbd_3, cbd_2, cbd_2 };
+static CBD_FUNC const cbd1[ML_KEM_1024_VARIANT + 1] = { cbd_3, cbd_2, cbd_2 };
 
 /*
  * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
@@ -1186,15 +1184,13 @@ int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
     int dv = vinfo->dv;
 
     /* FIPS 203 "y" vector */
-    if (!gencbd_vector(y, cbd_1, &counter, r, rank, mdctx, key))
+    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
         return 0;
-    vector_ntt(y, rank);
     /* FIPS 203 "v" scalar */
     inner_product(&v, key->t, y, rank);
     scalar_inverse_ntt(&v);
     /* FIPS 203 "u" vector */
-    matrix_mult(u, key->m, y, rank);
-    vector_inverse_ntt(u, rank);
+    matrix_mult_intt(u, key->m, y, rank);
 
     /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
     if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
@@ -1230,9 +1226,7 @@ decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
     int du = vinfo->du;
     int dv = vinfo->dv;
 
-    vector_decode(u, ctext, du, rank);
-    vector_decompress(u, du, rank);
-    vector_ntt(u, rank);
+    vector_decode_decompress_ntt(u, ctext, du, rank);
     scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
     scalar_decompress(&v, dv);
     inner_product(&mask, key->s, u, rank);
@@ -1337,14 +1331,13 @@ static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
  * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
  * inlined.
  *
- * The caller MUST also pass a pre-allocated scratch buffer |tmp| with room for
- * at least one "vector" (rank * sizeof(scalar)), and a pre-allocated digest
- * context that is not shared with any concurrent computation.
+ * The caller MUST pass a pre-allocated digest context that is not shared with
+ * any concurrent computation.
  *
- * This function outputs the serialised wire-form |ek| public key into  the
- * provided |pubenc| buffer, and generates the content of the |rho|, |pkhash|,
- * |t|, |m|, |s| and |z| components of the private |key| (which must have
- * preallocated space for these).
+ * This function optionally outputs the serialised wire-form |ek| public key
+ * into the provided |pubenc| buffer, and generates the content of the |rho|,
+ * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
+ * have preallocated space for these).
  *
  * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
  * domain separation.  These are concatenated and hashed to produce a pair of
@@ -1360,8 +1353,7 @@ static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
 static __owur
 int genkey(const uint8_t d[ML_KEM_RANDOM_BYTES],
            const uint8_t z[ML_KEM_RANDOM_BYTES],
-           scalar *tmp, EVP_MD_CTX *mdctx,
-           uint8_t *pubenc, ML_KEM_KEY *key)
+           EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
 {
     uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
     const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
@@ -1380,24 +1372,29 @@ int genkey(const uint8_t d[ML_KEM_RANDOM_BYTES],
     if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
         return 0;
     memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
+    /* FIPS 203 |e| vector is initial value of key->t */
     if (!matrix_expand(mdctx, key)
-        || !gencbd_vector(key->s, cbd_1, &counter, sigma, rank, mdctx, key))
-        return 0;
-    vector_ntt(key->s, rank);
-    /* FIPS 203 |e| vector */
-    if (!gencbd_vector(tmp, cbd_1, &counter, sigma, rank, mdctx, key))
+        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
+        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
         return 0;
-    vector_ntt(tmp, rank);
 
-    /* Fill in the public key */
-    matrix_mult_transpose(key->t, key->m, key->s, rank);
-    vector_add(key->t, tmp, rank);
-    encode_pubkey(pubenc, key);
-    if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
-        return 0;
+    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
+    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
+
+    if (pubenc == NULL) {
+        /* Incremental digest of public key without in-full serialisation. */
+        if (!hash_h_pubkey(key->pkhash, mdctx, key))
+            return 0;
+    } else {
+        encode_pubkey(pubenc, key);
+        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
+           return 0;
+    }
 
-    /* Save "z" portion of seed for "implicit rejection" on failure */
+    /* Save |z| portion of seed for "implicit rejection" on failure. */
     memcpy(key->z, z, ML_KEM_RANDOM_BYTES);
+    /* Also save |d| portion, in suport of likely alternative key format. */
+    memcpy(key->d = key->z + ML_KEM_RANDOM_BYTES, d, ML_KEM_RANDOM_BYTES);
     return 1;
 }
 
@@ -1496,11 +1493,15 @@ int add_storage(scalar *p, int private, ML_KEM_KEY *key)
         return 0;
     /* A public key needs space for |t| and |m| */
     key->m = (key->t = p) + rank;
-    /* A private key also needs space for |s| and |z| */
+    /*
+     * A private key also needs space for |s| and |z|.
+     * The |z| buffer always includes additional space for |d|, but a key's |d|
+     * pointer is left NULL when parsed from the NIST format, which omits that
+     * information.  Only keys generated from a (d, z) seed pair will have a
+     * non-NULL |d| pointer.
+     */
     if (private)
         key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
-    else
-        key->z = (uint8_t *)(key->s = NULL);
     return 1;
 }
 
@@ -1514,7 +1515,7 @@ free_storage(ML_KEM_KEY *key)
     if (key->t == NULL)
         return;
     OPENSSL_free(key->t);
-    key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
+    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
 }
 
 /*
@@ -1527,7 +1528,7 @@ free_storage(ML_KEM_KEY *key)
 /* Retrieve the parameters of one of the ML-KEM variants */
 const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int variant)
 {
-    if (variant > ML_KEM_1024)
+    if (variant > ML_KEM_1024_VARIANT)
         return NULL;
     return &vinfo_map[variant];
 }
@@ -1550,7 +1551,7 @@ ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
     key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
     key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
     key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);
-    key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
+    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
 
     if (key->shake128_md != NULL
         && key->shake256_md != NULL
@@ -1587,6 +1588,9 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
         break;
     case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
         ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
+        /* Duplicated keys retain |d|, if available */
+        if (key->d != NULL)
+            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
         break;
     }
 
@@ -1616,10 +1620,11 @@ void ossl_ml_kem_key_free(ML_KEM_KEY *key)
     /*-
      * Cleanse any sensitive data:
      * - The private vector |s| is immediately followed by the FO failure
-     *   secret |z|, we can cleanse both in one call.
+     *   secret |z|, and seed |d|, we can cleanse all three in one call.
      */
     if (key->s != NULL)
-        OPENSSL_cleanse(key->s, key->vinfo->vector_bytes + ML_KEM_RANDOM_BYTES);
+        OPENSSL_cleanse(key->s,
+                        key->vinfo->vector_bytes + 2 * ML_KEM_RANDOM_BYTES);
 
     /* Free the key material */
     OPENSSL_free(key->t);
@@ -1673,7 +1678,7 @@ int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
     return ret;
 }
 
-/* Parse input as a new private key  */
+/* Parse input as a new private key */
 int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
                                   ML_KEM_KEY *key)
 {
@@ -1700,13 +1705,14 @@ int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
 }
 
 /*
- * Generate a new keypair from a given seed, giving a deterministic result for
- * running tests.  The caller can elect to not collect the encoded public key.
+ * Generate a new keypair either from the input seeds (when non-null), yielding
+ * a deterministic result for running tests, or securely generated random data.
  */
-int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seedlen,
-                            uint8_t *pubenc, size_t publen,
-                            ML_KEM_KEY *key)
+int ossl_ml_kem_genkey(const uint8_t *d, const uint8_t *z,
+                       uint8_t *pubenc, size_t publen,
+                       ML_KEM_KEY *key)
 {
+    uint8_t tmpseed[2 * ML_KEM_RANDOM_BYTES];
     EVP_MD_CTX *mdctx = NULL;
     const ML_KEM_VINFO *vinfo;
     int ret = 0;
@@ -1715,74 +1721,29 @@ int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seedlen,
         return 0;
     vinfo = key->vinfo;
 
-    if (seed == NULL || seedlen != ML_KEM_SEED_BYTES
+    /* Both seeds or neither must be NULL */
+    if (((d == NULL) ^ (z == NULL)) != 0
         || (pubenc != NULL && publen != vinfo->pubkey_bytes)
         || (mdctx = EVP_MD_CTX_new()) == NULL)
         return 0;
 
-    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key)) {
-        const uint8_t *d = seed;
-        const uint8_t *z = seed + ML_KEM_RANDOM_BYTES;
-
-        /*-
-         * This avoids the need to handle allocation failures for two (max 2KB
-         * each) vectors and (if the caller does not want the public key) an
-         * encoded public key (max 1568 bytes), that are never retained on
-         * return from this function.
-         * We stack-allocate these.
-         */
-#       define case_genkey_seed(bits) \
-        case ML_KEM_##bits: \
-            if (pubenc != NULL) \
-            { \
-                scalar tmp[ML_KEM_##bits##_RANK]; \
-                                                             \
-                ret = genkey(d, z, tmp, mdctx, pubenc, key); \
-            } else { \
-                scalar tmp[ML_KEM_##bits##_RANK]; \
-                uint8_t encbuf[ML_KEM_##bits##_PUBLIC_KEY_BYTES]; \
-                                                             \
-                ret = genkey(d, z, tmp, mdctx, encbuf, key); \
-            } \
-            break
-        switch (vinfo->variant) {
-        case_genkey_seed(512);
-        case_genkey_seed(768);
-        case_genkey_seed(1024);
-        }
-#       undef case_genkey_seed
+    if (d == NULL) {
+        if (RAND_priv_bytes_ex(key->libctx, tmpseed, 2 * ML_KEM_RANDOM_BYTES,
+                              key->vinfo->secbits) <= 0)
+            return 0;
+        d = tmpseed;
+        z = tmpseed + ML_KEM_RANDOM_BYTES;
     }
 
+    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
+        ret = genkey(d, z, mdctx, pubenc, key);
+
     if (!ret)
         free_storage(key);
     EVP_MD_CTX_free(mdctx);
     return ret;
 }
 
-/*
- * Generate a new keypair from a random seed, using the library context's
- * private DRBG.  The caller can elect to not collect the seed or the encoded
- * public key.
- */
-int ossl_ml_kem_genkey_rand(uint8_t *seed, size_t seedlen,
-                            uint8_t *pubenc, size_t publen,
-                            ML_KEM_KEY *key)
-{
-    uint8_t tmpseed[ML_KEM_SEED_BYTES];
-    uint8_t *sptr = seed == NULL ? tmpseed : seed;
-
-    if (key == NULL
-        || ossl_ml_kem_have_pubkey(key)
-        || (seed != NULL && seedlen != sizeof(tmpseed)))
-        return 0;
-
-    if (RAND_priv_bytes_ex(key->libctx, sptr, sizeof(tmpseed),
-                           key->vinfo->secbits) <= 0)
-        return 0;
-
-    return ossl_ml_kem_genkey_seed(sptr, sizeof(tmpseed), pubenc, publen, key);
-}
-
 /*
  * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
  * This is the deterministic version with randomness supplied externally.
@@ -1812,7 +1773,7 @@ int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
      * We stack-allocate these.
      */
 #   define case_encap_seed(bits) \
-    case ML_KEM_##bits: \
+    case ML_KEM_##bits##_VARIANT: \
         { \
             scalar tmp[2 * ML_KEM_##bits##_RANK]; \
                                                                          \
@@ -1874,15 +1835,15 @@ int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
      * retained on return from this function.
      * We stack-allocate these.
      */
-#   define case_decap(bits) \
-    case ML_KEM_##bits: \
-        { \
-            uint8_t cbuf[ML_KEM_##bits##_CIPHERTEXT_BYTES]; \
-            scalar tmp[2 * ML_KEM_##bits##_RANK]; \
-                                                                      \
-            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
-            EVP_MD_CTX_free(mdctx); \
-            return ret; \
+#   define case_decap(bits)                                             \
+    case ML_KEM_##bits##_VARIANT:                                       \
+        {                                                               \
+            uint8_t cbuf[CTEXT_BYTES(bits)];                            \
+            scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
+                                                                        \
+            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
+            EVP_MD_CTX_free(mdctx);                                     \
+            return ret;                                                 \
         }
     switch (vinfo->variant) {
     case_decap(512);
index e4d6dffac46857395539f8765b269205e54a1a47..0604b967633c070775b23c4c5ad224890ff80cc2 100644 (file)
@@ -1,6 +1,6 @@
 /*
  * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
-
+ *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
  * in the file LICENSE in the source distribution or at
@@ -16,7 +16,6 @@
 # include <openssl/e_os2.h>
 # include <openssl/core_dispatch.h>
 # include <crypto/evp.h>
-# include <crypto/types.h>
 
 # define ML_KEM_DEGREE 256
 /*
  * of unity, the polynomial (X^256+1) splits in Z_q[X] into 128 irreducible
  * quadratic factors of the form (X^2 - zeta^(2i + 1)).  This is used to
  * implement efficient multiplication in the ring R_q via the "NTT" transform.
- *
- * 12 bits are sufficient to losslessly represent values in [0, q-1].
- * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
  */
 # define ML_KEM_PRIME          (ML_KEM_DEGREE * 13 + 1)
-# define ML_KEM_LOG2PRIME      12
-# define ML_KEM_INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
 
 /*
  * Various ML-KEM primitives need random input, 32-bytes at a time.  Key
  * - "secbits" is required security strength of the RNG for the random inputs.
  */
 
-/*
- * The wire form of a losslessly encoded vector (12-bits per element)
- */
-# define ML_KEM_VECTOR_BYTES(rank) \
-    ((3 * ML_KEM_DEGREE / 2) * (rank))
-
-/*
- * The wire-form public key consists of the lossless encoding of the vector
- * "t" = "A" * "s" + "e", followed by public seed "rho".
- */
-# define ML_KEM_PUBKEY_BYTES(rank) \
-    (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_RANDOM_BYTES)
-
-/*
- * Our internal serialised private key concatenates serialisations of "s", the
- * public key, the public key hash, and the failure secret "z".
- */
-# define ML_KEM_PRVKEY_BYTES(rank) \
-    (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_PUBKEY_BYTES(rank) \
-     + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)
-
-/*
- * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
- * (numbers modulo the ML-KEM prime "q") are lossily encoded using as "du" and
- * "dv" bits, respectively.  This encoding is the ciphertext input for
- * decapsulation.
- */
-# define ML_KEM_U_VECTOR_BYTES(rank, du) \
-    ((ML_KEM_DEGREE / 8) * (du) * (rank))
-# define ML_KEM_V_SCALAR_BYTES(dv) \
-    ((ML_KEM_DEGREE / 8) * (dv))
-# define ML_KEM_CTEXT_BYTES(rank, du, dv) \
-    (ML_KEM_U_VECTOR_BYTES(rank, du) + ML_KEM_V_SCALAR_BYTES(dv))
-
 /*
  * Variant-specific constants and structures
  * -----------------------------------------
  */
-# define ML_KEM_512_RANK       2
-# define ML_KEM_512_ETA1       3
-# define ML_KEM_512_ETA2       2
-# define ML_KEM_512_DU         10
-# define ML_KEM_512_DV         4
-# define ML_KEM_512_RNGSEC     128
-
-# define ML_KEM_768_RANK       3
-# define ML_KEM_768_ETA1       2
-# define ML_KEM_768_ETA2       2
-# define ML_KEM_768_DU         10
-# define ML_KEM_768_DV         4
-# define ML_KEM_768_RNGSEC     192
-
-# define ML_KEM_1024_RANK      4
-# define ML_KEM_1024_ETA1      2
-# define ML_KEM_1024_ETA2      2
-# define ML_KEM_1024_DU        11
-# define ML_KEM_1024_DV        5
-# define ML_KEM_1024_RNGSEC    256
+# define ML_KEM_512_VARIANT     0
+# define ML_KEM_512_BITS        512
+# define ML_KEM_512_RANK        2
+# define ML_KEM_512_ETA1        3
+# define ML_KEM_512_ETA2        2
+# define ML_KEM_512_DU          10
+# define ML_KEM_512_DV          4
+# define ML_KEM_512_SECBITS     128
+
+# define ML_KEM_768_VARIANT     1
+# define ML_KEM_768_BITS        768
+# define ML_KEM_768_RANK        3
+# define ML_KEM_768_ETA1        2
+# define ML_KEM_768_ETA2        2
+# define ML_KEM_768_DU          10
+# define ML_KEM_768_DV          4
+# define ML_KEM_768_SECBITS     192
+
+# define ML_KEM_1024_VARIANT    2
+# define ML_KEM_1024_BITS       1024
+# define ML_KEM_1024_RANK       4
+# define ML_KEM_1024_ETA1       2
+# define ML_KEM_1024_ETA2       2
+# define ML_KEM_1024_DU         11
+# define ML_KEM_1024_DV         5
+# define ML_KEM_1024_SECBITS    256
 
 /*
  * External variant-specific API
  * -----------------------------
  */
 
-/*
- * Each variant parameter set is associated with an ordinal number which
- * represents that parameter set.
- */
-# define ML_KEM_512     0
-# define ML_KEM_768     1
-# define ML_KEM_1024    2
-
 typedef struct {
     const char *algorithm_name;
-    size_t vector_bytes;
     size_t prvkey_bytes;
+    size_t prvalloc;
     size_t pubkey_bytes;
+    size_t puballoc;
     size_t ctext_bytes;
+    size_t vector_bytes;
     size_t u_vector_bytes;
-    size_t puballoc;
-    size_t prvalloc;
     int variant;
     int bits;
     int rank;
@@ -182,7 +140,7 @@ typedef struct {
 const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int variant);
 
 /* Known as ML_KEM_KEY via crypto/types.h */
-struct ossl_ml_kem_key_st {
+typedef struct ossl_ml_kem_key_st {
     /* Variant metadata, for one of ML-KEM-{512,768,1024} */
     const ML_KEM_VINFO *vinfo;
 
@@ -207,11 +165,12 @@ struct ossl_ml_kem_key_st {
     struct ossl_ml_kem_scalar_st *m;        /* Pre-computed pubkey matrix */
     struct ossl_ml_kem_scalar_st *s;        /* Private key secret vector */
     uint8_t *z;                             /* Private key FO failure secret */
+    uint8_t *d;                             /* Private key seed */
 
     /* Fixed-size/offset built-ins */
     uint8_t rho[ML_KEM_RANDOM_BYTES];       /* Matrix recovery seed */
     uint8_t pkhash[ML_KEM_PKHASH_BYTES];    /* Hash of wire-form public key */
-};
+} ML_KEM_KEY;
 
 /* The public key is always present, when the private is */
 # define ossl_ml_kem_key_vinfo(key)        ((key)->vinfo)
@@ -222,8 +181,6 @@ struct ossl_ml_kem_key_st {
  * ----- ML-KEM key lifecycle
  */
 
-# ifndef OPENSSL_NO_ML_KEM
-
 /*
  * Allocate a "bare" key for given ML-KEM variant. Initially without any public
  * or private key material.
@@ -257,13 +214,9 @@ __owur
 int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
                                   ML_KEM_KEY *key);
 __owur
-int ossl_ml_kem_genkey_rand(uint8_t *seed, size_t seedlen,
-                            uint8_t *pubenc, size_t publen,
-                            ML_KEM_KEY *key);
-__owur
-int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seed_len,
-                            uint8_t *pubenc, size_t publen,
-                            ML_KEM_KEY *key);
+int ossl_ml_kem_genkey(const uint8_t *d, const uint8_t *z,
+                       uint8_t *pubenc, size_t publen,
+                       ML_KEM_KEY *key);
 
 /*
  * Perform an ML-KEM operation with a given ML-KEM key.  The key can generally
@@ -295,6 +248,4 @@ int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
 __owur
 int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2);
 
-# endif /* OPENSSL_NO_ML_KEM */
-
 #endif  /* OPENSSL_HEADER_ML_KEM_H */
index c5e3d0effbb2940fb8539e100ea51472679389ed..ad17f052e45f5047ad498fa1c55538ee33bc93c4 100644 (file)
@@ -29,8 +29,4 @@ typedef struct dsa_st DSA;
 typedef struct ecx_key_st ECX_KEY;
 # endif
 
-# ifndef OPENSSL_NO_ML_KEM
-typedef struct ossl_ml_kem_key_st ML_KEM_KEY;
-# endif
-
 #endif
diff --git a/include/openssl/ml_kem.h b/include/openssl/ml_kem.h
new file mode 100644 (file)
index 0000000..2b731a5
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+#ifndef OPENSSL_ML_KEM_H
+# define OPENSSL_ML_KEM_H
+# pragma once
+
+# define OSSL_ML_KEM_SHARED_SECRET_BYTES    32
+
+# define OSSL_ML_KEM_512_BITS               512
+# define OSSL_ML_KEM_512_SECURITY_BITS      128
+# define OSSL_ML_KEM_512_CIPHERTEXT_BYTES   768
+# define OSSL_ML_KEM_512_PUBLIC_KEY_BYTES   800
+
+# define OSSL_ML_KEM_768_BITS               768
+# define OSSL_ML_KEM_768_SECURITY_BITS      192
+# define OSSL_ML_KEM_768_CIPHERTEXT_BYTES   1088
+# define OSSL_ML_KEM_768_PUBLIC_KEY_BYTES   1184
+
+# define OSSL_ML_KEM_1024_BITS              1024
+# define OSSL_ML_KEM_1024_SECURITY_BITS     256
+# define OSSL_ML_KEM_1024_CIPHERTEXT_BYTES  1568
+# define OSSL_ML_KEM_1024_PUBLIC_KEY_BYTES  1568
+
+#endif
index 273093b684564ed41a304a3839f9c614199a075e..3c2882601db98f5e094dfc35fa09a9b7454ccb72 100644 (file)
@@ -102,9 +102,9 @@ static const TLS_GROUP_CONSTANTS group_list[] = {
     { OSSL_TLS_GROUP_ID_ffdhe6144, 128, TLS1_3_VERSION, 0, -1, -1, 0 },
     { OSSL_TLS_GROUP_ID_ffdhe8192, 192, TLS1_3_VERSION, 0, -1, -1, 0 },
 #ifndef OPENSSL_NO_ML_KEM
-    { OSSL_TLS_GROUP_ID_mlkem512, ML_KEM_512_RNGSEC, TLS1_3_VERSION, 0, -1, -1, 1 },
-    { OSSL_TLS_GROUP_ID_mlkem768, ML_KEM_768_RNGSEC, TLS1_3_VERSION, 0, -1, -1, 1 },
-    { OSSL_TLS_GROUP_ID_mlkem1024, ML_KEM_1024_RNGSEC, TLS1_3_VERSION, 0, -1, -1, 1 },
+    { OSSL_TLS_GROUP_ID_mlkem512, ML_KEM_512_SECBITS, TLS1_3_VERSION, 0, -1, -1, 1 },
+    { OSSL_TLS_GROUP_ID_mlkem768, ML_KEM_768_SECBITS, TLS1_3_VERSION, 0, -1, -1, 1 },
+    { OSSL_TLS_GROUP_ID_mlkem1024, ML_KEM_1024_SECBITS, TLS1_3_VERSION, 0, -1, -1, 1 },
 #endif
 };
 
index 28d7bb1104fcc8c9b09e8eec1dd2def7229aab29..cda7cec963219d14fc287568388cbb11262e275b 100644 (file)
@@ -442,7 +442,8 @@ static void *ml_kem_gen(void *vgctx, OSSL_CALLBACK *osslcb, void *cbarg)
     ML_KEM_KEY *key;
     uint8_t *nopub = NULL;
     uint8_t *seed = gctx->seed;
-    size_t slen = seed == NULL ? 0 : ML_KEM_SEED_BYTES;
+    uint8_t *d = seed != NULL ? seed : NULL;
+    uint8_t *z = seed != NULL ? seed + ML_KEM_RANDOM_BYTES : NULL;
     int genok = 0;
 
     if (gctx == NULL
@@ -454,9 +455,7 @@ static void *ml_kem_gen(void *vgctx, OSSL_CALLBACK *osslcb, void *cbarg)
     if ((gctx->selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0)
         return key;
 
-    genok = seed != NULL
-        ? ossl_ml_kem_genkey_seed(seed, slen, nopub, 0, key)
-        : ossl_ml_kem_genkey_rand(NULL, slen, nopub, 0, key);
+    genok = ossl_ml_kem_genkey(d, z, nopub, 0, key);
 
     /* Erase the single-use seed */
     if (seed != NULL)
@@ -496,12 +495,12 @@ typedef void (*func_ptr_t)(void);
     static void *ml_kem_##bits##_new(void *provctx) \
     { \
         return ml_kem_new(provctx == NULL ? NULL : PROV_LIBCTX_OF(provctx), \
-                          NULL, ML_KEM_##bits); \
+                          NULL, ML_KEM_##bits##_VARIANT); \
     } \
     static void *ml_kem_##bits##_gen_init(void *provctx, int selection, \
                                           const OSSL_PARAM params[]) \
     { \
-        return ml_kem_gen_init(provctx, selection, params, ML_KEM_##bits); \
+        return ml_kem_gen_init(provctx, selection, params, ML_KEM_##bits##_VARIANT); \
     } \
     const OSSL_DISPATCH ossl_ml_kem_##bits##_keymgmt_functions[] = { \
         { OSSL_FUNC_KEYMGMT_NEW, (func_ptr_t) ml_kem_##bits##_new }, \
index 349755b3784a63c5c248374d6741f7d65f543bc1..1ffeb6fc2092509bdfe6074957e9b6dfa4ce21ae 100644 (file)
@@ -230,7 +230,7 @@ static int test_non_derandomised_ml_kem(void)
     if (!TEST_ptr(sha256 = EVP_MD_fetch(NULL, "sha256", NULL)))
         return 0;
 
-    for (i = ML_KEM_512; i < ML_KEM_1024; ++i) {
+    for (i = ML_KEM_512_VARIANT; i < ML_KEM_1024_VARIANT; ++i) {
         const ML_KEM_VINFO *v;
         OSSL_PARAM params[3], *p;
         uint8_t hash[32];
index 1505209d3ace1b9392f92f717d817842148e8475..54440313527d805a318bd0d008b8032324a45899 100644 (file)
@@ -107,7 +107,7 @@ static int sanity_test(void)
 
     decap_entropy = ml_kem_public_entropy + ML_KEM_RANDOM_BYTES;
 
-    for (i = ML_KEM_512; i < ML_KEM_1024; ++i) {
+    for (i = ML_KEM_512_VARIANT; i < ML_KEM_1024_VARIANT; ++i) {
         OSSL_PARAM params[3];
         uint8_t hash[32];
         uint8_t shared_secret[ML_KEM_SHARED_SECRET_BYTES];
@@ -145,8 +145,8 @@ static int sanity_test(void)
 
         ret2 = -2;
         /* Generate a private key */
-        if (!ossl_ml_kem_genkey_rand(NULL, 0, encoded_public_key,
-                                     v->pubkey_bytes, private_key))
+        if (!ossl_ml_kem_genkey(NULL, NULL, encoded_public_key,
+                                v->pubkey_bytes, private_key))
             goto done;
 
         /* Check that no more entropy is available! */