]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Add ML-KEM-768 implementation
authorMichael Baentsch <57787676+baentsch@users.noreply.github.com>
Mon, 11 Nov 2024 08:08:06 +0000 (09:08 +0100)
committerTomas Mraz <tomas@openssl.org>
Fri, 14 Feb 2025 09:47:46 +0000 (10:47 +0100)
Based on code from BoringSSL covered under Google CCLA
Original code at https://boringssl.googlesource.com/boringssl/+/HEAD/crypto/mlkem

- VSCode automatic formatting (andrewd@openssl.org)
- Just do some basic formatting to make diffs easier to read later: convert
  from 2 to 4 spaces, add newlines after function declarations, and move
  function open curly brace to new line (andrewd@openssl.org)
- Move variable init to beginning of each function (andrewd@openssl.org)
- Replace CBB API
- Fixing up constants and parameter lists
- Replace BORINGSSL_keccak calls with EVP calls
- Added library symbols and low-level test case
- Switch boringssl constant time routines for OpenSSL ones
- Data type assertion and negative test added
- Moved mlkem.h to include/crypto
- Changed function naming to be in line with ossl convention
- Remove Google license terms based on CCLA
- Add constant_time_lt_32
- Convert asserts to ossl_asserts where possible
- Add bssl keccak, pubK recreation, formatting
- Add provider interface to utilize mlkem768 code enabling TLS1.3 use
- Revert to OpenSSL DigestXOF
- Use EVP_MD_xof() to determine digest finalisation (pauli@openssl.org)
- Change APIs to return error codes; reference new IANA number; move static asserts
  to one place
- Remove boringssl keccak for good
- Fix coding style and return value checks
- ANSI C compatibility changes
- Remove static cache objects
- All internal retval functions used leading to some new retval functions

Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/25848)

19 files changed:
Configure
crypto/build.info
crypto/mlkem/build.info [new file with mode: 0644]
crypto/mlkem/mlkem768.c [new file with mode: 0644]
include/crypto/mlkem.h [new file with mode: 0644]
include/internal/constant_time.h
include/internal/tlsgroups.h
providers/common/capabilities.c
providers/defltprov.c
providers/implementations/include/prov/implementations.h
providers/implementations/include/prov/mlkem.h [new file with mode: 0644]
providers/implementations/include/prov/names.h
providers/implementations/kem/build.info
providers/implementations/kem/ml_kem.c [new file with mode: 0644]
providers/implementations/keymgmt/build.info
providers/implementations/keymgmt/mlkem_kmgmt.c [new file with mode: 0644]
test/build.info
test/evp_extra_test.c
test/mlkem_internal_test.c [new file with mode: 0644]

index 98ad2dc82483eef73852982905d780aeaeb5b20f..92fd97fd2ea38796ef4af4efe72ae2c9b883a131 100755 (executable)
--- a/Configure
+++ b/Configure
@@ -487,6 +487,7 @@ my @disablables = (
     "md4",
     "mdc2",
     "ml-dsa",
+    "mlkem",
     "module",
     "msan",
     "multiblock",
index e476b678da347c8c6871e1e595960afd206402cd..72d5305616bb305692ace2984a7255d5dd39dd40 100644 (file)
@@ -2,7 +2,7 @@
 # there for further explanations.
 SUBDIRS=objects buffer bio stack lhash hashtable rand evp asn1 pem x509 conf \
         txt_db pkcs7 pkcs12 ui kdf store property \
-        md2 md4 md5 sha mdc2 hmac ripemd whrlpool poly1305 \
+        md2 md4 md5 sha mdc2 mlkem hmac ripemd whrlpool poly1305 \
         siphash sm3 des aes rc2 rc4 rc5 idea aria bf cast camellia \
         seed sm4 chacha modes bn ec rsa dsa dh sm2 dso engine \
         err comp http ocsp cms ts srp cmac ct async ess crmf cmp encode_decode \
diff --git a/crypto/mlkem/build.info b/crypto/mlkem/build.info
new file mode 100644 (file)
index 0000000..5d4ebe3
--- /dev/null
@@ -0,0 +1,3 @@
+LIBS = ../../libcrypto
+
+SOURCE[../../libcrypto] = mlkem768.c
diff --git a/crypto/mlkem/mlkem768.c b/crypto/mlkem/mlkem768.c
new file mode 100644 (file)
index 0000000..8e038f4
--- /dev/null
@@ -0,0 +1,1197 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+/* Copyright (c) 2024, Google Inc. */
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+#include <openssl/rand.h>
+#include <openssl/err.h>
+#include <crypto/mlkem.h>
+#include <internal/sha3.h>
+#include <internal/constant_time.h>
+#include <internal/common.h>
+#ifndef NDEBUG
+# include <stdio.h>
+#endif
+
+#ifndef OPENSSL_NO_MLKEM
+
+/* Constants that are common across all sizes. */
+# define DEGREE 256
+static const size_t kBarrettMultiplier = 5039;
+static const unsigned kBarrettShift = 24;
+static const uint16_t kPrime = 3329;
+static const int kLog2Prime = 12;
+static const uint16_t kHalfPrime = (/* kPrime= */ 3329 - 1) / 2;
+
+/*
+ * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
+ * root of unity.
+ */
+static const uint16_t kInverseDegree = 3303;
+
+/* Rank-specific constants. */
+# define RANK768 3
+static const int kDU768 = 10;
+static const int kDV768 = 4;
+# define RANK1024 4
+static const int kDU1024 = 11;
+static const int kDV1024 = 5;
+
+static ossl_inline size_t compressed_vector_size(int rank)
+{
+    return (rank == RANK768 ? kDU768 : kDU1024) * (size_t)rank *
+        DEGREE / 8;
+}
+
+static ossl_inline size_t ciphertext_size(int rank)
+{
+    return compressed_vector_size(rank) +
+        (rank == RANK768 ? kDV768 : kDV1024) * DEGREE / 8;
+}
+
+typedef struct scalar {
+    /* On every function entry and exit, 0 <= c < kPrime. */
+    uint16_t c[DEGREE];
+} scalar;
+
+/* TODO(ML-KEM): possibly rename vector768 to allow for other algs */
+typedef struct vector {
+    scalar v[RANK768];
+} vector;
+
+/* TODO(ML-KEM): possibly rename matrix768 to allow for other algs */
+typedef struct matrix {
+    scalar v[RANK768][RANK768];
+} matrix;
+
+typedef struct public_key_RANK768 {
+    vector t;
+    uint8_t rho[32];
+    uint8_t public_key_hash[32];
+    matrix m;
+} public_key_RANK768;
+
+typedef struct private_key_RANK768 {
+    struct public_key_RANK768 pub;
+    vector s;
+    uint8_t fo_failure_secret[32];
+} private_key_RANK768;
+
+static ossl_inline size_t encoded_vector_size(int rank)
+{
+    return (kLog2Prime * DEGREE / 8) * (size_t)rank;
+}
+
+static ossl_inline size_t encoded_public_key_size(int rank)
+{
+    return encoded_vector_size(rank) + /* sizeof(rho)= */ 32;
+}
+
+/*
+ * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy
+ * necessary to encapsulate a secret. The entropy will be leaked to the
+ * decapsulating party.
+ */
+# define MLKEM_ENCAP_ENTROPY 32
+
+/* MD&XOF handles */
+
+/* Cache mgmt as per https://github.com/openssl/private/issues/700 */
+
+ossl_mlkem_ctx *ossl_mlkem_newctx(OSSL_LIB_CTX *libctx, const char *properties)
+{
+    ossl_mlkem_ctx *nctx = OPENSSL_zalloc(sizeof(ossl_mlkem_ctx));
+
+    /* replacing static asserts: */
+    if (nctx == NULL
+        || (OSSL_MLKEM768_SHARED_SECRET_BYTES != 32)
+        || (sizeof(unsigned int) < sizeof (uint32_t))
+        || (sizeof(struct ossl_mlkem768_public_key) <
+            sizeof(struct public_key_RANK768))
+        || (sizeof(struct ossl_mlkem768_private_key) <
+            sizeof(struct private_key_RANK768))
+        || (encoded_public_key_size(RANK768) != OSSL_MLKEM768_PUBLIC_KEY_BYTES)
+        || (encoded_public_key_size(RANK1024) != OSSL_MLKEM1024_PUBLIC_KEY_BYTES)
+        || (ciphertext_size(RANK768) != OSSL_MLKEM768_CIPHERTEXT_BYTES)
+        || (ciphertext_size(RANK1024) != OSSL_MLKEM1024_CIPHERTEXT_BYTES))
+        goto err;
+
+    nctx->shake128_cache = EVP_MD_fetch(libctx, "SHAKE128", properties);
+    nctx->shake256_cache = EVP_MD_fetch(libctx, "SHAKE256", properties);
+    nctx->sha3_256_cache = EVP_MD_fetch(libctx, "SHA3-256", properties);
+    nctx->sha3_512_cache = EVP_MD_fetch(libctx, "SHA3-512", properties);
+    nctx->libctx = libctx;
+    if (properties != NULL)
+        if ((nctx->properties = OPENSSL_strdup(properties)) == NULL)
+            goto err;
+    if (nctx->shake128_cache == NULL || nctx->shake256_cache == NULL ||
+        nctx->sha3_256_cache == NULL || nctx->sha3_512_cache == NULL)
+        goto err;
+    return nctx;
+
+err:
+    ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR);
+    ossl_mlkem_ctx_free(nctx);
+    return NULL;
+}
+
+void ossl_mlkem_ctx_free(ossl_mlkem_ctx *ctx)
+{
+    if (ctx != NULL) {
+        EVP_MD_free(ctx->shake128_cache);
+        EVP_MD_free(ctx->shake256_cache);
+        EVP_MD_free(ctx->sha3_256_cache);
+        EVP_MD_free(ctx->sha3_512_cache);
+        OPENSSL_free(ctx->properties);
+    }
+    OPENSSL_free(ctx);
+}
+
+/*
+ * single_keccak hashes |in_len| bytes from |in| and writes |out_len| bytes
+ * of output to |out|. If the |md| specifies a fixed-output function, like
+ * SHA3-256, then |out_len| must be the correct length for that function.
+ */
+static int single_keccak(uint8_t *out, size_t out_len,
+                         const uint8_t *in, size_t in_len,
+                         EVP_MD *md)
+{
+    EVP_MD_CTX *mdctx;
+    int ret = 0;
+
+    mdctx = EVP_MD_CTX_new();
+    if (mdctx == NULL
+        || !EVP_DigestInit_ex(mdctx, md, NULL)
+        || !EVP_DigestUpdate(mdctx, in, in_len))
+        return 0;
+
+    if (EVP_MD_xof(md))
+        ret = EVP_DigestFinalXOF(mdctx, out, out_len);
+    else
+        ret = EVP_DigestFinal_ex(mdctx, out, NULL);
+    EVP_MD_CTX_free(mdctx);
+
+    return ret;
+}
+
+/* TODO(ML-KEM) revisit utility of this function/remove eventually */
+static void print_hex(const uint8_t *data, int len, const char *msg)
+{
+# ifndef NDEBUG
+    if (msg)
+        printf("%s: \n", msg);
+    BIO_dump_fp(stdout, data, len);
+    fflush(0);
+# endif
+}
+
+/*
+ * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy
+ * necessary to encapsulate a secret. The entropy will be leaked to the
+ * decapsulating party.
+ */
+# define MLKEM_ENCAP_ENTROPY 32
+
+/* See https://csrc.nist.gov/pubs/fips/203/final */
+static int prf(uint8_t *out, size_t out_len, const uint8_t in[33],
+               ossl_mlkem_ctx *mlkem_ctx)
+{
+    return single_keccak(out, out_len, in, 33, mlkem_ctx->shake256_cache);
+}
+
+/*
+ * Section 4.1
+ * uint8_t out[32]
+ */
+static int hash_h(uint8_t *out, const uint8_t *in, size_t len,
+                  ossl_mlkem_ctx *mlkem_ctx)
+{
+    return single_keccak(out, 32, in, len, mlkem_ctx->sha3_256_cache);
+}
+
+/* uint8_t out[64] */
+static int hash_g(uint8_t *out, const uint8_t *in, size_t len,
+                  ossl_mlkem_ctx *mlkem_ctx)
+{
+    return single_keccak(out, 64, in, len, mlkem_ctx->sha3_512_cache);
+}
+
+/*
+ * This is called `J` in the spec.
+ * uint8_t out[ossl_mlkem768_SHARED_SECRET_BYTES],
+ * const uint8_t failure_secret[32]
+ */
+static int kdf(uint8_t *out,
+               const uint8_t *failure_secret, const uint8_t *ciphertext,
+               size_t ciphertext_len,
+               ossl_mlkem_ctx *mlkem_ctx)
+{
+    EVP_MD_CTX *mdctx;
+    int ret = 0;
+
+    mdctx = EVP_MD_CTX_new();
+    if (mdctx == NULL
+        || !EVP_DigestInit_ex(mdctx, mlkem_ctx->shake256_cache, NULL)
+        || !EVP_DigestUpdate(mdctx, failure_secret, 32)
+        || !EVP_DigestUpdate(mdctx, ciphertext, ciphertext_len)
+        || !EVP_DigestFinalXOF(mdctx, out, OSSL_MLKEM768_SHARED_SECRET_BYTES))
+        goto end;
+
+    ret = 1;
+
+end:
+    EVP_MD_CTX_free(mdctx);
+    return ret;
+}
+
+/*
+ * This bit of Python will be referenced in some of the following comments:
+ *
+ * p = 3329
+ *
+ * def bitreverse(i):
+ *     ret = 0
+ *     for n in range(7):
+ *         bit = i & 1
+ *         ret <<= 1
+ *         ret |= bit
+ *         i >>= 1
+ *     return ret
+ * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
+ */
+
+static const uint16_t kNTTRoots[128] = {
+    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
+    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
+    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
+    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
+    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
+    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
+    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
+    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
+    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
+    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
+    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
+};
+
+/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
+static const uint16_t kInverseNTTRoots[128] = {
+    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
+    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
+    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
+    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
+    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
+    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
+    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
+    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
+    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
+    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
+    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
+};
+
+/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
+static const uint16_t kModRoots[128] = {
+    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
+    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
+    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
+    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
+    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
+    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
+    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
+    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
+    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
+    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
+    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
+};
+
+/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
+static uint16_t reduce_once(uint16_t x)
+{
+    const uint16_t subtracted = x - kPrime;
+    uint16_t mask = 0u - (subtracted >> 15);
+
+    assert(x < 2 * kPrime);
+    /*
+     * On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of
+     * ML-KEM overall and Clang still produces constant-time code using `csel`. On
+     * other platforms & compilers on godbolt that we care about, this code also
+     * produces constant-time output.
+     */
+    return (mask & x) | (~mask & subtracted);
+}
+
+/*
+ * constant time reduce x mod kPrime using Barrett reduction. x must be less
+ * than kPrime + 2xkPrime^2.
+ */
+static uint16_t reduce(uint32_t x)
+{
+    uint64_t product = (uint64_t)x * kBarrettMultiplier;
+    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
+    uint32_t remainder = x - quotient * kPrime;
+
+    assert(x < kPrime + 2u * kPrime * kPrime);
+    return reduce_once(remainder);
+}
+
+static void scalar_zero(scalar *out)
+{
+    memset(out, 0, sizeof(*out));
+}
+
+static void vector_zero(vector *out)
+{
+    memset(out->v, 0, sizeof(scalar) * RANK768);
+}
+
+/*
+ * In place number theoretic transform of a given scalar.
+ * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
+ * transform leaves off the last iteration of the usual FFT code, with the 128
+ * relevant roots of unity being stored in |kNTTRoots|. This means the output
+ * should be seen as 128 elements in GF(3329^2), with the coefficients of the
+ * elements being consecutive entries in |s->c|.
+ */
+static void scalar_ntt(scalar *s)
+{
+    int offset = DEGREE;
+    int k, step, i, j;
+    uint32_t step_root;
+    uint16_t odd, even;
+
+    /*
+     * `int` is used here because using `size_t` throughout caused a ~5% slowdown
+     * with Clang 14 on Aarch64.
+     */
+    for (step = 1; step < DEGREE / 2; step <<= 1) {
+        offset >>= 1;
+        k = 0;
+        for (i = 0; i < step; i++) {
+            step_root = kNTTRoots[i + step];
+            for (j = k; j < k + offset; j++) {
+                odd = reduce(step_root * s->c[j + offset]);
+                even = s->c[j];
+                s->c[j] = reduce_once(odd + even);
+                s->c[j + offset] = reduce_once(even - odd + kPrime);
+            }
+            k += 2 * offset;
+        }
+    }
+}
+
+static void vector_ntt(vector *a)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++)
+        scalar_ntt(&a->v[i]);
+}
+
+/*
+ * In place inverse number theoretic transform of a given scalar, with pairs of
+ * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
+ * number theoretic transform, this leaves off the first step of the normal iFFT
+ * to account for the fact that 3329 does not have a 512th root of unity, using
+ * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
+ */
+static void scalar_inverse_ntt(scalar *s)
+{
+    int step = DEGREE / 2;
+    int offset, k, i, j;
+    uint32_t step_root;
+    uint16_t odd, even;
+
+    /*
+     * `int` is used here because using `size_t` throughout caused a ~5% slowdown
+     * with Clang 14 on Aarch64.
+     */
+    for (offset = 2; offset < DEGREE; offset <<= 1) {
+        step >>= 1;
+        k = 0;
+        for (i = 0; i < step; i++) {
+            step_root = kInverseNTTRoots[i + step];
+            for (j = k; j < k + offset; j++) {
+                odd = s->c[j + offset];
+                even = s->c[j];
+                s->c[j] = reduce_once(odd + even);
+                s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
+            }
+            k += 2 * offset;
+        }
+    }
+    for (i = 0; i < DEGREE; i++)
+        s->c[i] = reduce(s->c[i] * kInverseDegree);
+}
+
+static void vector_inverse_ntt(vector *a)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++)
+        scalar_inverse_ntt(&a->v[i]);
+}
+
+static void scalar_add(scalar *lhs, const scalar *rhs)
+{
+    int i;
+
+    for (i = 0; i < DEGREE; i++)
+        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
+}
+
+static void scalar_sub(scalar *lhs, const scalar *rhs)
+{
+    int i;
+
+    for (i = 0; i < DEGREE; i++)
+        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
+}
+
+/*
+ * Multiplying two scalars in the number theoretically transformed state. Since
+ * 3329 does not have a 512th root of unity, this means we have to interpret
+ * the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
+ * - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
+ * stored in the precomputed |kModRoots| table. Note that our Barrett transform
+ * only allows us to multipy two reduced numbers together, so we need some
+ * intermediate reduction steps, even if an uint64_t could hold 3 multiplied
+ * numbers.
+ */
+static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
+{
+    int i;
+    uint32_t real_real, img_img, real_img, img_real;
+
+    for (i = 0; i < DEGREE / 2; i++) {
+        real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
+        img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
+        real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
+        img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
+        out->c[2 * i] =
+            reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
+        out->c[2 * i + 1] = reduce(img_real + real_img);
+    }
+}
+
+static void vector_add(vector *lhs, const vector *rhs)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++)
+        scalar_add(&lhs->v[i], &rhs->v[i]);
+}
+
+static void matrix_mult(vector *out, const matrix *m,
+                        const vector *a)
+{
+    int i, j;
+
+    vector_zero(out);
+    for (i = 0; i < RANK768; i++) {
+        for (j = 0; j < RANK768; j++) {
+            scalar product;
+
+            scalar_mult(&product, &m->v[i][j], &a->v[j]);
+            scalar_add(&out->v[i], &product);
+        }
+    }
+}
+
+static void matrix_mult_transpose(vector *out, const matrix *m,
+                                  const vector *a)
+{
+    int i, j;
+
+    vector_zero(out);
+    for (i = 0; i < RANK768; i++) {
+        for (j = 0; j < RANK768; j++) {
+            scalar product;
+
+            scalar_mult(&product, &m->v[j][i], &a->v[j]);
+            scalar_add(&out->v[i], &product);
+        }
+    }
+}
+
+static void scalar_inner_product(scalar *out, const vector *lhs,
+                                 const vector *rhs)
+{
+    int i;
+
+    scalar_zero(out);
+    for (i = 0; i < RANK768; i++) {
+        scalar product;
+
+        scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
+        scalar_add(out, &product);
+    }
+}
+
+/*
+ * Algorithm 6 from the spec. Rejection samples a Keccak stream to get
+ * uniformly distributed elements. This is used for matrix expansion and only
+ * operates on public inputs.
+ */
+static int scalar_from_keccak_vartime(scalar *out, EVP_MD_CTX *mdctx)
+{
+    int done = 0;
+    uint8_t block[168];
+    size_t i;
+    uint16_t d1, d2;
+
+    while (done < DEGREE) {
+        if (!EVP_DigestSqueeze(mdctx, block, sizeof(block)))
+            return 0;
+        for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
+            d1 = block[i] + 256 * (block[i + 1] % 16);
+            d2 = block[i + 1] / 16 + 16 * block[i + 2];
+            if (d1 < kPrime)
+                out->c[done++] = d1;
+            if (d2 < kPrime && done < DEGREE)
+                out->c[done++] = d2;
+        }
+    }
+    return 1;
+}
+
+/*
+ * Algorithm 7 from the spec, with eta fixed to two and the PRF call
+ * included. Creates binominally distributed elements by sampling 2*|eta| bits,
+ * and setting the coefficient to the count of the first bits minus the count of
+ * the second bits, resulting in a centered binomial distribution. Since eta is
+ * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
+ * and 0 with probability 3/8.
+ */
+static
+int scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
+                                                         const uint8_t input[33],
+                                                         ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t entropy[128];
+    int i;
+    uint8_t byte;
+    uint16_t value;
+
+    assert(sizeof(entropy) == 2 * /* kEta= */ 2 * DEGREE / 8);
+    if (!prf(entropy, sizeof(entropy), input, mlkem_ctx))
+        return 0;
+    for (i = 0; i < DEGREE; i += 2) {
+        byte = entropy[i / 2];
+        value = kPrime;
+        value += (byte & 1) + ((byte >> 1) & 1);
+        value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
+        out->c[i] = reduce_once(value);
+        byte >>= 4;
+        value = kPrime;
+        value += (byte & 1) + ((byte >> 1) & 1);
+        value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
+        out->c[i + 1] = reduce_once(value);
+    }
+    return 1;
+}
+
+/*
+ * Generates a secret vector by using
+ * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
+ * appending and incrementing |counter| for entry of the vector.
+ */
+static int vector_generate_secret_eta_2(vector *out, uint8_t *counter,
+                                        const uint8_t seed[32],
+                                        ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t input[33];
+    int i;
+
+    memcpy(input, seed, 32);
+    for (i = 0; i < RANK768; i++) {
+        input[32] = (*counter)++;
+        if (!scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
+                                                                  input, mlkem_ctx))
+            return 0;
+    }
+    return 1;
+}
+
+/* Expands the matrix of a seed for key generation and for encaps-CPA. */
+static int matrix_expand(matrix *out, const uint8_t rho[32],
+                         ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t input[34];
+    int i, j, ret = 0;
+    EVP_MD_CTX *mdctx = EVP_MD_CTX_new();
+
+    if (mdctx == NULL)
+        goto end;
+    memcpy(input, rho, 32);
+    for (i = 0; i < RANK768; i++) {
+        for (j = 0; j < RANK768; j++) {
+            input[32] = i;
+            input[33] = j;
+            if (!EVP_DigestInit_ex(mdctx, mlkem_ctx->shake128_cache, NULL)
+                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
+                || !scalar_from_keccak_vartime(&out->v[i][j], mdctx))
+                goto end;
+        }
+    }
+
+    ret = 1;
+end:
+    EVP_MD_CTX_free(mdctx);
+    return ret;
+}
+
+static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
+                                  0x1f, 0x3f, 0x7f, 0xff};
+
+static void scalar_encode(uint8_t *out, const scalar *s, int bits)
+{
+    uint8_t out_byte = 0;
+    int out_byte_bits = 0;
+    int i, element_bits_done, chunk_bits, out_bits_remaining;
+    uint16_t element;
+
+    assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
+    for (i = 0; i < DEGREE; i++) {
+        element = s->c[i];
+        element_bits_done = 0;
+        while (element_bits_done < bits) {
+            chunk_bits = bits - element_bits_done;
+            out_bits_remaining = 8 - out_byte_bits;
+            if (chunk_bits >= out_bits_remaining) {
+                chunk_bits = out_bits_remaining;
+                out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
+                *out = out_byte;
+                out++;
+                out_byte_bits = 0;
+                out_byte = 0;
+            } else {
+                out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
+                out_byte_bits += chunk_bits;
+            }
+            element_bits_done += chunk_bits;
+            element >>= chunk_bits;
+        }
+    }
+    if (out_byte_bits > 0)
+        *out = out_byte;
+}
+
+/*
+ * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
+ * uint8_t out[32]
+ */
+static void scalar_encode_1(uint8_t *out, const scalar *s)
+{
+    int i, j;
+    uint8_t out_byte;
+
+    for (i = 0; i < DEGREE; i += 8) {
+        out_byte = 0;
+        for (j = 0; j < 8; j++)
+            out_byte |= (s->c[i + j] & 1) << j;
+        *out = out_byte;
+        out++;
+    }
+}
+
+/*
+ * Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
+ * (DEGREE) is divisible by 8, the individual vector entries will always fill a
+ * whole number of bytes, so we do not need to worry about bit packing here.
+ */
+static void vector_encode(uint8_t *out, const vector *a, int bits)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++)
+        scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
+}
+
+/*
+ * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
+ * |out|. It returns one on success and zero if any parsed value is >=
+ * |kPrime|.
+ */
+static int scalar_decode(scalar *out, const uint8_t *in, int bits)
+{
+    uint8_t in_byte = 0;
+    int in_byte_bits_left = 0;
+    int i, element_bits_done, chunk_bits;
+    uint16_t element;
+
+    if (!ossl_assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1))
+        return 0;
+    for (i = 0; i < DEGREE; i++) {
+        element = 0;
+        element_bits_done = 0;
+        while (element_bits_done < bits) {
+            if (in_byte_bits_left == 0) {
+                in_byte = *in;
+                in++;
+                in_byte_bits_left = 8;
+            }
+            chunk_bits = bits - element_bits_done;
+            if (chunk_bits > in_byte_bits_left)
+                chunk_bits = in_byte_bits_left;
+            element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
+            in_byte_bits_left -= chunk_bits;
+            in_byte >>= chunk_bits;
+            element_bits_done += chunk_bits;
+        }
+        if (element >= kPrime)
+            return 0;
+        out->c[i] = element;
+    }
+    return 1;
+}
+
+/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
+static void scalar_decode_1(scalar *out, const uint8_t in[32])
+{
+    int i, j;
+    uint8_t in_byte;
+
+    for (i = 0; i < DEGREE; i += 8) {
+        in_byte = *in;
+        in++;
+        for (j = 0; j < 8; j++) {
+            out->c[i + j] = in_byte & 1;
+            in_byte >>= 1;
+        }
+    }
+}
+
+/*
+ * Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
+ * success or zero if any parsed value is >= |kPrime|.
+ */
+static int vector_decode(vector *out, const uint8_t *in, int bits)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++) {
+        if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits))
+            return 0;
+    }
+    return 1;
+}
+
+/*
+ * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
+ * numbers close to each other together. The formula used is
+ * round(2^|bits|/kPrime*x) mod 2^|bits|.
+ * Uses Barrett reduction to achieve constant time. Since we need both the
+ * remainder (for rounding) and the quotient (as the result), we cannot use
+ * |reduce| here, but need to do the Barrett reduction directly.
+ */
+static uint16_t compress(uint16_t x, int bits)
+{
+    uint32_t shifted = (uint32_t)x << bits;
+    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
+    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
+    uint32_t remainder = shifted - quotient * kPrime;
+
+    /*
+     * Adjust the quotient to round correctly:
+     *   0 <= remainder <= kHalfPrime round to 0
+     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
+     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
+     */
+    assert(remainder < 2u * kPrime);
+    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
+    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
+    return quotient & ((1 << bits) - 1);
+}
+
+/*
+ * Decompresses |x| by using an equi-distant representative. The formula is
+ * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
+ * implement this logic using only bit operations.
+ */
+static uint16_t decompress(uint16_t x, int bits)
+{
+    uint32_t product = (uint32_t)x * kPrime;
+    uint32_t power = 1 << bits;
+    /* This is |product| % power, since |power| is a power of 2. */
+    uint32_t remainder = product & (power - 1);
+    /* This is |product| / power, since |power| is a power of 2. */
+    uint32_t lower = product >> bits;
+
+    /*
+     * The rounding logic works since the first half of numbers mod |power| have a
+     * 0 as first bit, and the second half has a 1 as first bit, since |power| is
+     * a power of 2. As a 12 bit number, |remainder| is always positive, so we
+     * will shift in 0s for a right shift.
+     */
+    return lower + (remainder >> (bits - 1));
+}
+
+static void scalar_compress(scalar *s, int bits)
+{
+    int i;
+
+    for (i = 0; i < DEGREE; i++)
+        s->c[i] = compress(s->c[i], bits);
+}
+
+static void scalar_decompress(scalar *s, int bits)
+{
+    int i;
+
+    for (i = 0; i < DEGREE; i++)
+        s->c[i] = decompress(s->c[i], bits);
+}
+
+static void vector_compress(vector *a, int bits)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++)
+        scalar_compress(&a->v[i], bits);
+}
+
+static void vector_decompress(vector *a, int bits)
+{
+    int i;
+
+    for (i = 0; i < RANK768; i++)
+        scalar_decompress(&a->v[i], bits);
+}
+
+static
+public_key_RANK768 *public_key_768_from_external(const ossl_mlkem768_public_key *external)
+{
+    return (struct public_key_RANK768 *)external;
+}
+
+static
+private_key_RANK768 *private_key_768_from_external(const ossl_mlkem768_private_key *external)
+{
+    return (struct private_key_RANK768 *)external;
+}
+
+static int mlkem_marshal_public_key(uint8_t *out,
+                                    const struct public_key_RANK768 *pub)
+{
+    /*
+     * replace CBB logic with straight copy to out and memcpy of rho at tail end
+     * TODO(ML-KEM): Check this is OK to protect incorrect buffer(sizes) passed
+     * possibly use WPACKET?
+     */
+    vector_encode(out, &pub->t, kLog2Prime);
+    memcpy(out + encoded_vector_size(RANK768), pub->rho, sizeof(pub->rho));
+    return 1;
+}
+
+int ossl_mlkem768_recreate_public_key(const uint8_t *encoded_public_key,
+                                      ossl_mlkem768_public_key *ext_pub,
+                                      ossl_mlkem_ctx *mlkem_ctx)
+{
+    struct public_key_RANK768 *pub = public_key_768_from_external(ext_pub);
+
+    print_hex(encoded_public_key, OSSL_MLKEM768_PUBLIC_KEY_BYTES, "Encoded key");
+    if (!vector_decode(&pub->t, encoded_public_key, kLog2Prime))
+        return 0;
+    memcpy(pub->rho, encoded_public_key + encoded_vector_size(RANK768), sizeof(pub->rho));
+    if (!matrix_expand(&pub->m, pub->rho, mlkem_ctx)
+        || (!hash_h(pub->public_key_hash, encoded_public_key,
+                    encoded_public_key_size(RANK768), mlkem_ctx)))
+        return 0;
+    print_hex((uint8_t *)pub, sizeof(public_key_RANK768), "recreated PK");
+    return 1;
+}
+
+static int mlkem_generate_key_external_seed(uint8_t *out_encoded_public_key,
+                                            private_key_RANK768 *priv,
+                                            const uint8_t *seed,
+                                            ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t augmented_seed[33];
+    uint8_t hashed[64];
+    const uint8_t *const rho = hashed;
+    const uint8_t *const sigma = hashed + 32;
+    uint8_t counter = 0;
+    vector error;
+
+    if (mlkem_ctx == NULL)
+        return 0;
+
+    memcpy(augmented_seed, seed, 32);
+    augmented_seed[32] = RANK768;
+    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mlkem_ctx))
+        return 0;
+    memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
+    if (!matrix_expand(&priv->pub.m, rho, mlkem_ctx)
+        || (!vector_generate_secret_eta_2(&priv->s, &counter, sigma, mlkem_ctx)))
+        return 0;
+    vector_ntt(&priv->s);
+    if (!vector_generate_secret_eta_2(&error, &counter, sigma, mlkem_ctx))
+        return 0;
+    vector_ntt(&error);
+    matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
+    vector_add(&priv->pub.t, &error);
+    if (!mlkem_marshal_public_key(out_encoded_public_key, &priv->pub)
+        || !hash_h(priv->pub.public_key_hash, out_encoded_public_key,
+                   encoded_public_key_size(RANK768), mlkem_ctx))
+        return 0;
+    memcpy(priv->fo_failure_secret, seed + 32, 32);
+    return 1;
+}
+
+static
+int ossl_mlkem768_generate_key_external_seed(uint8_t *out_encoded_public_key,
+                                             ossl_mlkem768_private_key *out_private_key,
+                                             const uint8_t *seed,
+                                             ossl_mlkem_ctx *mlkem_ctx)
+{
+    private_key_RANK768 *priv = NULL;
+
+    priv = private_key_768_from_external(out_private_key);
+    return mlkem_generate_key_external_seed(out_encoded_public_key, priv, seed, mlkem_ctx);
+}
+
+int ossl_mlkem768_generate_key(uint8_t *out_encoded_public_key,
+                               uint8_t *optional_out_seed,
+                               ossl_mlkem768_private_key *out_private_key,
+                               ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t seed[MLKEM_SEED_BYTES];
+
+    if (mlkem_ctx == NULL)
+        return 0;
+
+    /* TODO(ML-KEM): Review requested randomness strength */
+    if (RAND_priv_bytes_ex(mlkem_ctx->libctx, seed, sizeof(seed), 256) == 1) {
+        if (optional_out_seed)
+            memcpy(optional_out_seed, seed, sizeof(seed));
+        return ossl_mlkem768_generate_key_external_seed(out_encoded_public_key,
+                                                        out_private_key,
+                                                        seed, mlkem_ctx);
+    }
+    return 0;
+}
+
+int ossl_mlkem768_private_key_from_seed(ossl_mlkem768_private_key *out_private_key,
+                                        const uint8_t *seed, size_t seed_len,
+                                        ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t public_key_bytes[OSSL_MLKEM768_PUBLIC_KEY_BYTES];
+
+    if (seed_len != MLKEM_SEED_BYTES)
+        return 0;
+    ossl_mlkem768_generate_key_external_seed(public_key_bytes, out_private_key,
+                                             seed, mlkem_ctx);
+    return 1;
+}
+
+int ossl_mlkem768_public_from_private(ossl_mlkem768_public_key *out_public_key,
+                                      const ossl_mlkem768_private_key *private_key)
+{
+    struct public_key_RANK768 *const pub = public_key_768_from_external(out_public_key);
+    const struct private_key_RANK768 *const priv =
+        private_key_768_from_external(private_key);
+
+    if (priv == NULL)
+        return 0;
+    *pub = priv->pub;
+    return 1;
+}
+
+/*
+ * Encrypts a message with given randomness to
+ * the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
+ * would not result in a CCA secure scheme, since lattice schemes are vulnerable
+ * to decryption failure oracles.
+ */
+static int encrypt_cpa(uint8_t *out, const struct public_key_RANK768 *pub,
+                       const uint8_t *message,
+                       const uint8_t *randomness,
+                       ossl_mlkem_ctx *mlkem_ctx)
+{
+    int du = kDU768;
+    int dv = kDV768;
+    uint8_t counter = 0;
+    vector secret, error;
+    uint8_t input[33];
+    scalar scalar_error;
+    vector u;
+    scalar v;
+    scalar expanded_message;
+
+    if (!vector_generate_secret_eta_2(&secret, &counter, randomness, mlkem_ctx))
+        return 0;
+    vector_ntt(&secret);
+    if (!vector_generate_secret_eta_2(&error, &counter, randomness, mlkem_ctx))
+        return 0;
+    memcpy(input, randomness, 32);
+    input[32] = counter;
+    if (!scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input, mlkem_ctx))
+        return 0;
+    matrix_mult(&u, &pub->m, &secret);
+    vector_inverse_ntt(&u);
+    vector_add(&u, &error);
+    scalar_inner_product(&v, &pub->t, &secret);
+    scalar_inverse_ntt(&v);
+    scalar_add(&v, &scalar_error);
+    scalar_decode_1(&expanded_message, message);
+    scalar_decompress(&expanded_message, 1);
+    scalar_add(&v, &expanded_message);
+    vector_compress(&u, du);
+    vector_encode(out, &u, du);
+    scalar_compress(&v, dv);
+    scalar_encode(out + compressed_vector_size(RANK768), &v, dv);
+    return 1;
+}
+
+/*
+ * See section 6.2.
+ * entropy[MLKEM_ENCAP_ENTROPY])
+ */
+static int mlkem_encap_external_entropy(uint8_t *out_ciphertext,
+                                        uint8_t *out_shared_secret,
+                                        const public_key_RANK768 *pub,
+                                        const uint8_t *entropy,
+                                        ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t input[64];
+    uint8_t key_and_randomness[64];
+
+    memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
+    memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
+           sizeof(input) - MLKEM_ENCAP_ENTROPY);
+    if (!hash_g(key_and_randomness, input, sizeof(input), mlkem_ctx)
+        || !encrypt_cpa(out_ciphertext, pub, entropy,
+                        key_and_randomness + 32, mlkem_ctx))
+        return 0;
+    memcpy(out_shared_secret, key_and_randomness, 32);
+    return 1;
+}
+
+/*
+ * out_ciphertext[ossl_mlkem768_CIPHERTEXT_BYTES],
+ * out_shared_secret[ossl_mlkem768_SHARED_SECRET_BYTES],
+ * entropy[MLKEM_ENCAP_ENTROPY])
+ */
+static
+int ossl_mlkem768_encap_external_entropy(uint8_t *out_ciphertext,
+                                         uint8_t *out_shared_secret,
+                                         const ossl_mlkem768_public_key *public_key,
+                                         const uint8_t *entropy,
+                                         ossl_mlkem_ctx *mlkem_ctx)
+{
+    const struct public_key_RANK768 *pub =
+        public_key_768_from_external(public_key);
+
+    return mlkem_encap_external_entropy(out_ciphertext, out_shared_secret, pub,
+                                        entropy, mlkem_ctx);
+}
+
+/* Calls |ossl_mlkem768_encap_external_entropy| with random bytes from |RAND_bytes| */
+int ossl_mlkem768_encap(uint8_t *out_ciphertext,
+                        uint8_t *out_shared_secret,
+                        const ossl_mlkem768_public_key *public_key,
+                        ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t entropy[MLKEM_ENCAP_ENTROPY];
+
+    if (mlkem_ctx == NULL)
+        return 0;
+
+    /* TODO(ML-KEM): Review requested randomness strength */
+    if (RAND_bytes_ex(mlkem_ctx->libctx, entropy, MLKEM_ENCAP_ENTROPY, 256) != 1
+        || !ossl_mlkem768_encap_external_entropy(out_ciphertext, out_shared_secret, public_key,
+                                                 entropy, mlkem_ctx))
+        return 0;
+    print_hex((uint8_t *)public_key, sizeof(ossl_mlkem768_public_key), "PK");
+    print_hex(out_shared_secret, OSSL_MLKEM768_SHARED_SECRET_BYTES, "SS2");
+    print_hex(out_ciphertext, OSSL_MLKEM768_CIPHERTEXT_BYTES, "CT2");
+    return 1;
+}
+
+static void decrypt_cpa(uint8_t *out, const struct private_key_RANK768 *priv,
+                        const uint8_t *ciphertext)
+{
+    int du = kDU768;
+    int dv = kDV768;
+    vector u;
+    scalar v, mask;
+
+    vector_decode(&u, ciphertext, du);
+    vector_decompress(&u, du);
+    vector_ntt(&u);
+    scalar_decode(&v, ciphertext + compressed_vector_size(RANK768), dv);
+    scalar_decompress(&v, dv);
+    scalar_inner_product(&mask, &priv->s, &u);
+    scalar_inverse_ntt(&mask);
+    scalar_sub(&v, &mask);
+    scalar_compress(&v, 1);
+    scalar_encode_1(out, &v);
+}
+
+/* See section 6.3 */
+static int mlkem_decap(uint8_t *out_shared_secret,
+                       const uint8_t *ciphertext,
+                       const struct private_key_RANK768 *priv,
+                       ossl_mlkem_ctx *mlkem_ctx)
+{
+    uint8_t decrypted[64];
+    uint8_t key_and_randomness[64];
+    size_t ciphertext_len = ciphertext_size(RANK768);
+    /* TODO(ML-KEM): Maximum also applicable for other algs? */
+    uint8_t expected_ciphertext[OSSL_MLKEM1024_CIPHERTEXT_BYTES];
+    uint8_t failure_key[32];
+    uint8_t mask;
+    int i;
+
+    print_hex((uint8_t *)&priv->pub, sizeof(ossl_mlkem768_public_key), "PK1");
+    print_hex(ciphertext, OSSL_MLKEM768_CIPHERTEXT_BYTES, "CT");
+    decrypt_cpa(decrypted, priv, ciphertext);
+    memcpy(decrypted + 32, priv->pub.public_key_hash,
+           sizeof(decrypted) - 32);
+    if (!hash_g(key_and_randomness, decrypted, sizeof(decrypted), mlkem_ctx))
+        return 0;
+    assert(ciphertext_len <= sizeof(expected_ciphertext));
+    encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
+                key_and_randomness + 32, mlkem_ctx);
+    kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len, mlkem_ctx);
+    mask = constant_time_eq_int_8(CRYPTO_memcmp(ciphertext,
+                                                expected_ciphertext, ciphertext_len), 0);
+    for (i = 0; i < OSSL_MLKEM768_SHARED_SECRET_BYTES; i++)
+        out_shared_secret[i] = constant_time_select_8(mask,
+                                                      key_and_randomness[i],
+                                                      failure_key[i]);
+    print_hex(out_shared_secret, OSSL_MLKEM768_SHARED_SECRET_BYTES, "SS");
+    return 1;
+}
+
+int ossl_mlkem768_decap(uint8_t *out_shared_secret,
+                        const uint8_t *ciphertext, size_t ciphertext_len,
+                        const ossl_mlkem768_private_key *private_key,
+                        ossl_mlkem_ctx *mlkem_ctx)
+{
+    const struct private_key_RANK768 *priv;
+
+    if (mlkem_ctx == NULL)
+        return 0;
+
+    if (ciphertext_len != OSSL_MLKEM768_CIPHERTEXT_BYTES) {
+        /* TODO(ML-KEM): Review requested randomness strength */
+        RAND_bytes_ex(mlkem_ctx->libctx, out_shared_secret,
+                      OSSL_MLKEM768_SHARED_SECRET_BYTES, 256);
+        return 0;
+    }
+    priv = private_key_768_from_external(private_key);
+    return mlkem_decap(out_shared_secret, ciphertext, priv, mlkem_ctx);
+}
+
+#endif /* OPENSSL_NO_MLKEM */
diff --git a/include/crypto/mlkem.h b/include/crypto/mlkem.h
new file mode 100644 (file)
index 0000000..3b4321c
--- /dev/null
@@ -0,0 +1,179 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+/* Copyright (c) 2024, Google Inc. */
+
+#ifndef OPENSSL_HEADER_MLKEM_H
+# define OPENSSL_HEADER_MLKEM_H
+
+# include <stdint.h>
+# include <openssl/e_os2.h>
+# include <crypto/evp.h>
+
+# if defined(__cplusplus)
+extern "C" {
+# endif
+
+# ifndef OPENSSL_NO_MLKEM
+
+    typedef struct ossl_mlkem_ctx {
+        EVP_MD *shake128_cache;
+        EVP_MD *shake256_cache;
+        EVP_MD *sha3_256_cache;
+        EVP_MD *sha3_512_cache;
+        OSSL_LIB_CTX *libctx;
+        char *properties;
+    } ossl_mlkem_ctx;
+
+    /* General ctx functions */
+    ossl_mlkem_ctx *ossl_mlkem_newctx(OSSL_LIB_CTX *libctx, const char *properties);
+
+    void ossl_mlkem_ctx_free(ossl_mlkem_ctx *ctx);
+
+    /*
+     * ML-KEM-768.
+     *
+     * This implements the Module-Lattice-Based Key-Encapsulation Mechanism from
+     * https://csrc.nist.gov/pubs/fips/203/final
+     */
+
+    /*
+     * ossl_mlkem768_public_key contains an ML-KEM-768 public key. The contents of this
+     * object should never leave the address space since the format is unstable.
+     */
+
+    /* TODO: Review alignment as per https://github.com/openssl/private/issues/702 */
+
+    typedef struct ossl_mlkem768_public_key {
+        union {
+            uint8_t bytes[512 * (3 + 9) + 32 + 32];
+            uint16_t alignment;
+        } opaque;
+    } ossl_mlkem768_public_key;
+
+    /*
+     * ossl_mlkem768_private_key contains an ML-KEM-768 private key. The contents of this
+     * object should never leave the address space since the format is unstable.
+     */
+    typedef struct ossl_mlkem768_private_key {
+        union {
+            uint8_t bytes[512 * (3 + 3 + 9) + 32 + 32 + 32];
+            uint16_t alignment;
+        } opaque;
+    } ossl_mlkem768_private_key;
+
+/*
+ * Parameters from FIPS 203 Section 8: Parameter Sets
+ * Reference: https://csrc.nist.gov/pubs/fips/203/final
+ */
+
+#  define OSSL_MLKEM768_SECURITY_BITS 192
+
+  /*
+   * OSSL_MLKEM1024_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-1024
+   * public key.
+   */
+#  define OSSL_MLKEM1024_PUBLIC_KEY_BYTES 1568
+
+  /*
+   * OSSL_MLKEM768_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-768
+   * public key.
+   */
+#  define OSSL_MLKEM768_PUBLIC_KEY_BYTES 1184
+
+  /* MLKEM_SEED_BYTES is the number of bytes in an ML-KEM seed. */
+#  define MLKEM_SEED_BYTES 64
+
+    /*
+     * ossl_mlkem768_generate_key generates a random public/private key pair, writes the
+     * encoded public key to |out_encoded_public_key| and sets |out_private_key| to
+     * the private key. If |optional_out_seed| is not NULL then the seed used to
+     * generate the private key is written to it as well.
+     * out_encoded_public_key must be allocated to store ossl_mlkem768_PUBLIC_KEY_BYTES
+     * optional_out_seed if not NULL must be allocated to store MLKEM_SEED_BYTES
+     */
+    int ossl_mlkem768_generate_key(uint8_t *out_encoded_public_key,
+                                   uint8_t *optional_out_seed,
+                                   struct ossl_mlkem768_private_key *out_private_key,
+                                   ossl_mlkem_ctx *mlkem_ctx);
+
+    /*
+     * ossl_mlkem768_private_key_from_seed derives a private key from a seed that was
+     * generated by |ossl_mlkem768_generate_key|. It fails and returns 0 if |seed_len| is
+     * incorrect, otherwise it writes |*out_private_key| and returns 1.
+     */
+    int ossl_mlkem768_private_key_from_seed(ossl_mlkem768_private_key *out_private_key,
+                                            const uint8_t *seed,
+                                            size_t seed_len,
+                                            ossl_mlkem_ctx *mlkem_ctx);
+
+    /*
+     * ossl_mlkem768_public_from_private sets |*out_public_key| to the public key that
+     * corresponds to |private_key|. (This is faster than parsing the output of
+     * |ossl_mlkem768_generate_key| if, for some reason, you need to encapsulate to a key
+     * that was just generated.)
+     */
+    int ossl_mlkem768_public_from_private(ossl_mlkem768_public_key *out_public_key,
+                                          const ossl_mlkem768_private_key *private_key);
+
+  /* ossl_mlkem1024_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-1024 ciphertext. */
+#  define OSSL_MLKEM1024_CIPHERTEXT_BYTES 1568
+
+  /* ossl_mlkem768_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-768 ciphertext. */
+#  define OSSL_MLKEM768_CIPHERTEXT_BYTES 1088
+
+  /* ossl_mlkem768_SHARED_SECRET_BYTES is the number of bytes in an ML-KEM shared secret. */
+#  define OSSL_MLKEM768_SHARED_SECRET_BYTES 32
+
+    /*
+     * ossl_mlkem768_encap encrypts a random shared secret for |public_key|, writes the
+     * ciphertext to |out_ciphertext|, and writes the random shared secret to
+     * |out_shared_secret|.
+     * it is assumed out_ciphertext has been allocated ossl_mlkem768_CIPHERTEXT_BYTES bytes
+     * and out_shared_secret has been allocated MLKEM_SHARED_SECRET_BYTES bytes
+     */
+    int ossl_mlkem768_encap(uint8_t *out_ciphertext,
+                            uint8_t *out_shared_secret,
+                            const ossl_mlkem768_public_key *public_key,
+                            ossl_mlkem_ctx *mlkem_ctx);
+
+    /*
+     * ossl_mlkem768_decap decrypts a shared secret from |ciphertext| using |private_key|
+     * and writes it to |out_shared_secret|. If |ciphertext_len| is incorrect it
+     * returns 0, otherwise it returns 1. If |ciphertext| is invalid (but of the
+     * correct length), |out_shared_secret| is filled with a key that will always be
+     * the same for the same |ciphertext| and |private_key|, but which appears to be
+     * random unless one has access to |private_key|. These alternatives occur in
+     * constant time. Any subsequent symmetric encryption using |out_shared_secret|
+     * must use an authenticated encryption scheme in order to discover the
+     * decapsulation failure.
+     * it is assumed out_shared_secret has been allocated MLKEM_SHARED_SECRET_BYTES bytes
+     */
+    int ossl_mlkem768_decap(uint8_t *out_shared_secret,
+                            const uint8_t *ciphertext, size_t ciphertext_len,
+                            const ossl_mlkem768_private_key *private_key,
+                            ossl_mlkem_ctx *mlkem_ctx);
+
+    /*
+     * ossl_mlkem768_recreate_public_key recreates a fully formed ossl_mlkem768_public_key
+     * from an input |encoded_public_key| of size ossl_mlkem768_PUBLIC_KEY_BYTES.
+     * |pub| is expected to point to an allocated memory area of
+     * sizeof(ossl_mlkem768_public_key)
+     */
+    int ossl_mlkem768_recreate_public_key(const uint8_t *encoded_public_key,
+                                          ossl_mlkem768_public_key *pub,
+                                          ossl_mlkem_ctx *mlkem_ctx);
+
+# endif /* OPENSSL_NO_MLKEM */
+
+# if defined(__cplusplus)
+}  /* extern C */
+# endif
+
+#endif  /* OPENSSL_HEADER_MLKEM_H */
index 1f480d84d883494400b4b58fc1ca0759c8a3c7fe..2cb4bbc2cd196f8ceb561b4ca472a2ac7fa996a3 100644 (file)
@@ -42,7 +42,6 @@ static ossl_inline unsigned int constant_time_lt(unsigned int a,
 /* Convenience method for getting an 8-bit mask. */
 static ossl_inline unsigned char constant_time_lt_8(unsigned int a,
                                                     unsigned int b);
-
 /* Convenience method for uint32_t. */
 static ossl_inline uint32_t constant_time_lt_32(uint32_t a, uint32_t b);
 
index 73fb53bc5ff89f53ca4c993134e7f1719d6b3a5d..2507bb1887742615ac894534779e02fd5d3af77c 100644 (file)
 # define OSSL_TLS_GROUP_ID_ffdhe4096        0x0102
 # define OSSL_TLS_GROUP_ID_ffdhe6144        0x0103
 # define OSSL_TLS_GROUP_ID_ffdhe8192        0x0104
+ /*
+  * TODO(ML-KEM): Update to 513 as per IANA
+  * https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-8
+  * Not done yet to not break interop testing with OQS test server; change when that
+  * gets updated in line with whatever IANA eventually defines
+  */
+# define OSSL_TLS_GROUP_ID_mlkem768 0x0768
 
 #endif
index 78099ecf659e4ce6066f85ebb676e40c91758c4a..6ca2b39efeb4177ab459cd26ec170c0b45ff18fe 100644 (file)
@@ -18,6 +18,7 @@
 #include "internal/tlsgroups.h"
 #include "prov/providercommon.h"
 #include "internal/e_os.h"
+#include "crypto/mlkem.h"
 
 /* If neither ec or dh is available then we have no TLS-GROUP capabilities */
 #if !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_DH)
@@ -28,73 +29,76 @@ typedef struct tls_group_constants_st {
     int maxtls;              /* Maximum TLS version (or 0 for undefined) */
     int mindtls;             /* Minimum DTLS version, -1 unsupported */
     int maxdtls;             /* Maximum DTLS version (or 0 for undefined) */
+    int is_kem;              /* Indicates utility as KEM */
 } TLS_GROUP_CONSTANTS;
 
 static const TLS_GROUP_CONSTANTS group_list[] = {
     { OSSL_TLS_GROUP_ID_sect163k1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect163r1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect163r2, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect193r1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect193r2, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect233k1, 112, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect233r1, 112, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect239k1, 112, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect283k1, 128, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect283r1, 128, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect409k1, 192, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect409r1, 192, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect571k1, 256, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_sect571r1, 256, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp160k1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp160r1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp160r2, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp192k1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp192r1, 80, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp224k1, 112, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp224r1, 112, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_secp256k1, 128, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
-    { OSSL_TLS_GROUP_ID_secp256r1, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0 },
-    { OSSL_TLS_GROUP_ID_secp384r1, 192, TLS1_VERSION, 0, DTLS1_VERSION, 0 },
-    { OSSL_TLS_GROUP_ID_secp521r1, 256, TLS1_VERSION, 0, DTLS1_VERSION, 0 },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
+    { OSSL_TLS_GROUP_ID_secp256r1, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 },
+    { OSSL_TLS_GROUP_ID_secp384r1, 192, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 },
+    { OSSL_TLS_GROUP_ID_secp521r1, 256, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 },
     { OSSL_TLS_GROUP_ID_brainpoolP256r1, 128, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_brainpoolP384r1, 192, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
     { OSSL_TLS_GROUP_ID_brainpoolP512r1, 256, TLS1_VERSION, TLS1_2_VERSION,
-      DTLS1_VERSION, DTLS1_2_VERSION },
-    { OSSL_TLS_GROUP_ID_x25519, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0 },
-    { OSSL_TLS_GROUP_ID_x448, 224, TLS1_VERSION, 0, DTLS1_VERSION, 0 },
-    { OSSL_TLS_GROUP_ID_brainpoolP256r1_tls13, 128, TLS1_3_VERSION, 0, -1, -1 },
-    { OSSL_TLS_GROUP_ID_brainpoolP384r1_tls13, 192, TLS1_3_VERSION, 0, -1, -1 },
-    { OSSL_TLS_GROUP_ID_brainpoolP512r1_tls13, 256, TLS1_3_VERSION, 0, -1, -1 },
+      DTLS1_VERSION, DTLS1_2_VERSION, 0 },
+    { OSSL_TLS_GROUP_ID_x25519, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 },
+    { OSSL_TLS_GROUP_ID_x448, 224, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 },
+    { OSSL_TLS_GROUP_ID_brainpoolP256r1_tls13, 128, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_brainpoolP384r1_tls13, 192, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_brainpoolP512r1_tls13, 256, TLS1_3_VERSION, 0, -1, -1, 0 },
     /* Security bit values as given by BN_security_bits() */
-    { OSSL_TLS_GROUP_ID_ffdhe2048, 112, TLS1_3_VERSION, 0, -1, -1 },
-    { OSSL_TLS_GROUP_ID_ffdhe3072, 128, TLS1_3_VERSION, 0, -1, -1 },
-    { OSSL_TLS_GROUP_ID_ffdhe4096, 128, TLS1_3_VERSION, 0, -1, -1 },
-    { OSSL_TLS_GROUP_ID_ffdhe6144, 128, TLS1_3_VERSION, 0, -1, -1 },
-    { OSSL_TLS_GROUP_ID_ffdhe8192, 192, TLS1_3_VERSION, 0, -1, -1 },
+    { OSSL_TLS_GROUP_ID_ffdhe2048, 112, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_ffdhe3072, 128, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_ffdhe4096, 128, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_ffdhe6144, 128, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_ffdhe8192, 192, TLS1_3_VERSION, 0, -1, -1, 0 },
+    { OSSL_TLS_GROUP_ID_mlkem768, OSSL_MLKEM768_SECURITY_BITS,
+      TLS1_3_VERSION, 0, -1, -1, 1 },
 };
 
 #define TLS_GROUP_ENTRY(tlsname, realname, algorithm, idx) \
@@ -120,10 +124,12 @@ static const TLS_GROUP_CONSTANTS group_list[] = {
                         (unsigned int *)&group_list[idx].mindtls), \
         OSSL_PARAM_int(OSSL_CAPABILITY_TLS_GROUP_MAX_DTLS, \
                         (unsigned int *)&group_list[idx].maxdtls), \
+        OSSL_PARAM_int(OSSL_CAPABILITY_TLS_GROUP_IS_KEM, \
+                       (unsigned int *)&group_list[idx].is_kem), \
         OSSL_PARAM_END \
     }
 
-static const OSSL_PARAM param_group_list[][10] = {
+static const OSSL_PARAM param_group_list[][11] = {
 # ifndef OPENSSL_NO_EC
 #  ifndef OPENSSL_NO_EC2M
     TLS_GROUP_ENTRY("sect163k1", "sect163k1", "EC", 0),
@@ -204,6 +210,8 @@ static const OSSL_PARAM param_group_list[][10] = {
     TLS_GROUP_ENTRY("ffdhe6144", "ffdhe6144", "DH", 36),
     TLS_GROUP_ENTRY("ffdhe8192", "ffdhe8192", "DH", 37),
 # endif
+    /* TODO(ML-KEM): Decide final name, e.g., ML-KEM768 or MLKEM768 */
+    TLS_GROUP_ENTRY("MLKEM768", "MLKEM768", "ML-KEM-768", 38),
 };
 #endif /* !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_DH) */
 
index ccc1e8e7e9bb50bbb3b388f39f6b91022f438d76..9b8781e532ea1d7a035fcdbaf0350d9ff494ae19 100644 (file)
@@ -482,6 +482,7 @@ static const OSSL_ALGORITHM deflt_asym_kem[] = {
 # endif
     { PROV_NAMES_EC, "provider=default", ossl_ec_asym_kem_functions },
 #endif
+    { PROV_NAMES_MLKEM768, "provider=default", ossl_mlkem768_asym_kem_functions },
     { NULL, NULL, NULL }
 };
 
@@ -544,6 +545,8 @@ static const OSSL_ALGORITHM deflt_keymgmt[] = {
     { PROV_NAMES_SM2, "provider=default", ossl_sm2_keymgmt_functions,
       PROV_DESCS_SM2 },
 #endif
+    { PROV_NAMES_MLKEM768, "provider=default", ossl_mlkem768_keymgmt_functions,
+      PROV_DESCS_MLKEM768 },
     { NULL, NULL, NULL }
 };
 
index 3863d96a4073baf10a3eb0d6c9100d5b7885b0db..c156c3b2bf3333578fabcaedc123594232755ecd 100644 (file)
@@ -324,6 +324,7 @@ extern const OSSL_DISPATCH ossl_sm2_keymgmt_functions[];
 extern const OSSL_DISPATCH ossl_ml_dsa_44_keymgmt_functions[];
 extern const OSSL_DISPATCH ossl_ml_dsa_65_keymgmt_functions[];
 extern const OSSL_DISPATCH ossl_ml_dsa_87_keymgmt_functions[];
+extern const OSSL_DISPATCH ossl_mlkem768_keymgmt_functions[];
 
 /* Key Exchange */
 extern const OSSL_DISPATCH ossl_dh_keyexch_functions[];
@@ -400,6 +401,7 @@ extern const OSSL_DISPATCH ossl_sm2_asym_cipher_functions[];
 extern const OSSL_DISPATCH ossl_rsa_asym_kem_functions[];
 extern const OSSL_DISPATCH ossl_ecx_asym_kem_functions[];
 extern const OSSL_DISPATCH ossl_ec_asym_kem_functions[];
+extern const OSSL_DISPATCH ossl_mlkem768_asym_kem_functions[];
 
 /* Encoders */
 extern const OSSL_DISPATCH ossl_rsa_to_PKCS1_der_encoder_functions[];
diff --git a/providers/implementations/include/prov/mlkem.h b/providers/implementations/include/prov/mlkem.h
new file mode 100644 (file)
index 0000000..4eb4202
--- /dev/null
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+#ifndef OSSL_INTERNAL_MLKEM_H
+# define OSSL_INTERNAL_MLKEM_H
+# pragma once
+
+# ifndef OPENSSL_NO_MLKEM
+
+#  include <stdint.h>
+#  include <crypto/mlkem.h>
+
+#  define MLKEM_KEY_TYPE_512     0
+#  define MLKEM_KEY_TYPE_768     1
+#  define MLKEY_KEY_TYPE_1024    2
+
+typedef struct mlkem768_key_st {
+    int keytype;
+    ossl_mlkem768_private_key seckey;
+    ossl_mlkem768_public_key pubkey;
+    uint8_t *encoded_pubkey;
+    int pubkey_initialized;
+    int seckey_initialized;
+    ossl_mlkem_ctx *mlkem_ctx;
+    void *provctx;
+} MLKEM768_KEY;
+
+# endif /* OPENSSL_NO_MLKEM */
+
+#endif /* OSSL_INTERNAL_MLKEM_H */
index 9280be0bbea855444cdb37105a52339edee48000..e556be86bb4695b746a7d89f95758577d00328c5 100644 (file)
 #define PROV_DESCS_ML_DSA_65 "OpenSSL ML-DSA-65 implementation"
 #define PROV_NAMES_ML_DSA_87 "ML-DSA-87:2.16.840.1.101.3.4.3.19:id-ml-dsa-87"
 #define PROV_DESCS_ML_DSA_87 "OpenSSL ML-DSA-87 implementation"
+#define PROV_NAMES_MLKEM768 "ML-KEM-768"
+#define PROV_DESCS_MLKEM768 "OpenSSL ML-KEM-768 implementation"
index 4a6a58ff654a8b73c07ac21ce472f2ef787ef7b7..b452323c219e9feba75f3bf7b00a4e5f8ef077e9 100644 (file)
@@ -4,6 +4,7 @@
 $RSA_KEM_GOAL=../../libdefault.a ../../libfips.a
 $EC_KEM_GOAL=../../libdefault.a
 $TEMPLATE_KEM_GOAL=../../libtemplate.a
+$ML_KEM_GOAL=../../libdefault.a
 
 SOURCE[$RSA_KEM_GOAL]=rsa_kem.c
 
@@ -15,3 +16,4 @@ IF[{- !$disabled{ec} -}]
 ENDIF
 
 SOURCE[$TEMPLATE_KEM_GOAL]=template_kem.c
+SOURCE[$ML_KEM_GOAL]=ml_kem.c
diff --git a/providers/implementations/kem/ml_kem.c b/providers/implementations/kem/ml_kem.c
new file mode 100644 (file)
index 0000000..99a55dd
--- /dev/null
@@ -0,0 +1,210 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+#include <string.h>
+#include <openssl/crypto.h>
+#include <openssl/evp.h>
+#include <openssl/core_dispatch.h>
+#include <openssl/core_names.h>
+#include <openssl/params.h>
+#include <openssl/err.h>
+#include <openssl/proverr.h>
+#include "prov/provider_ctx.h"
+#include "prov/implementations.h"
+#include "prov/securitycheck.h"
+#include "prov/providercommon.h"
+#include "prov/mlkem.h"
+
+#define BUFSIZE 1000
+#if defined(NDEBUG) || defined(OPENSSL_NO_STDIO)
+/* TODO(ML-KEM) to remove or replace with TRACE */
+static void debug_print(char *fmt, ...)
+{
+}
+#else
+static void debug_print(char *fmt, ...)
+{
+    char out[BUFSIZE];
+    va_list argptr;
+
+    va_start(argptr, fmt);
+    vsnprintf(out, BUFSIZE, fmt, argptr);
+    va_end(argptr);
+    if (getenv("TEMPLATEKM"))
+        fprintf(stderr, "TEMPLATE_KM: %s", out);
+}
+#endif
+
+typedef struct {
+    OSSL_LIB_CTX *libctx;
+    MLKEM768_KEY *key;
+    int op;
+} PROV_MLKEM_CTX;
+
+static OSSL_FUNC_kem_newctx_fn mlkem_newctx;
+static OSSL_FUNC_kem_encapsulate_init_fn mlkem_encapsulate_init;
+static OSSL_FUNC_kem_encapsulate_fn mlkem_encapsulate;
+static OSSL_FUNC_kem_decapsulate_init_fn mlkem_decapsulate_init;
+static OSSL_FUNC_kem_decapsulate_fn mlkem_decapsulate;
+static OSSL_FUNC_kem_freectx_fn mlkem_freectx;
+static OSSL_FUNC_kem_set_ctx_params_fn mlkem_set_ctx_params;
+
+static void *mlkem_newctx(void *provctx)
+{
+    PROV_MLKEM_CTX *ctx = OPENSSL_zalloc(sizeof(*ctx));
+
+    debug_print("MLKEMKEM newctx called\n");
+    if (ctx == NULL)
+        return NULL;
+
+    ctx->libctx = PROV_LIBCTX_OF(provctx);
+
+    debug_print("MLKEMKEM newctx returns %p\n", ctx);
+    return ctx;
+}
+
+static void mlkem_freectx(void *vctx)
+{
+    PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx;
+
+    debug_print("MLKEMKEM freectx %p\n", ctx);
+    OPENSSL_free(ctx);
+}
+
+static int mlkem_init(void *vctx, int operation, void *vkey, void *vauth,
+                      ossl_unused const OSSL_PARAM params[])
+{
+    PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx;
+    MLKEM768_KEY *mlkemkey = vkey;
+
+    debug_print("MLKEMKEM init %p / %p\n", ctx, mlkemkey);
+    if (!ossl_prov_is_running())
+        return 0;
+
+    if (mlkemkey->keytype != MLKEM_KEY_TYPE_768)
+        return 0;
+
+    ctx->key = mlkemkey;
+    ctx->op = operation;
+    debug_print("MLKEMKEM init OK\n");
+    return 1;
+}
+
+static int mlkem_encapsulate_init(void *vctx, void *vkey,
+                                  const OSSL_PARAM params[])
+{
+    return mlkem_init(vctx, EVP_PKEY_OP_ENCAPSULATE, vkey, NULL, params);
+}
+
+static int mlkem_decapsulate_init(void *vctx, void *vkey,
+                                  const OSSL_PARAM params[])
+{
+    return mlkem_init(vctx, EVP_PKEY_OP_DECAPSULATE, vkey, NULL, params);
+}
+
+static int mlkem_set_ctx_params(void *vctx, const OSSL_PARAM params[])
+{
+    PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx;
+
+    debug_print("MLKEMKEM set ctx params %p\n", ctx);
+    if (ctx == NULL)
+        return 0;
+    if (params == NULL)
+        return 1;
+
+    debug_print("MLKEMKEM set ctx params OK\n");
+    return 1;
+}
+
+static const OSSL_PARAM known_settable_mlkem_ctx_params[] = {
+    OSSL_PARAM_END
+};
+
+static const OSSL_PARAM *mlkem_settable_ctx_params(ossl_unused void *vctx,
+                                                   ossl_unused void *provctx)
+{
+    return known_settable_mlkem_ctx_params;
+}
+
+static int mlkem_encapsulate(void *vctx, unsigned char *out, size_t *outlen,
+                             unsigned char *secret, size_t *secretlen)
+{
+    PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx;
+    int ret;
+
+    debug_print("MLKEMKEM encaps %p to %p\n", ctx, out);
+    if (outlen != NULL)
+        *outlen = OSSL_MLKEM768_CIPHERTEXT_BYTES;
+    if (secretlen != NULL)
+        *secretlen = OSSL_MLKEM768_SHARED_SECRET_BYTES;
+
+    if (out == NULL) {
+        debug_print("MLKEMKEM encaps outlens set to %ld and %ld\n", *outlen, *secretlen);
+        return 1;
+    }
+
+    if (ctx->key == NULL
+            || ctx->key->keytype != MLKEM_KEY_TYPE_768
+            || ctx->key->pubkey_initialized == 0
+            || secret == NULL)
+        return 0;
+
+    ret = ossl_mlkem768_encap(out, (uint8_t *)secret, &ctx->key->pubkey, ctx->key->mlkem_ctx);
+
+    debug_print("MLKEMKEM encaps returns %d\n", ret);
+    return ret;
+}
+
+static int mlkem_decapsulate(void *vctx, unsigned char *out, size_t *outlen,
+                             const unsigned char *in, size_t inlen)
+{
+    PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx;
+    int ret;
+
+    debug_print("MLKEMKEM decaps %p to %p\n", ctx, out);
+    debug_print("MLKEMKEM decaps inlen at %ld\n", inlen);
+    if (outlen != NULL)
+        *outlen = OSSL_MLKEM768_SHARED_SECRET_BYTES;
+
+    if (out == NULL) {
+        debug_print("MLKEMKEM decaps outlen set to %ld \n", *outlen);
+        return 1;
+    }
+
+    if (ctx->key == NULL
+            || ctx->key->keytype != MLKEM_KEY_TYPE_768
+            || ctx->key->seckey_initialized == 0
+            || in == NULL)
+        return 0;
+
+    if (inlen != OSSL_MLKEM768_CIPHERTEXT_BYTES)
+        return 0;
+
+    ret = ossl_mlkem768_decap((uint8_t *)out, (uint8_t *)in, inlen, &ctx->key->seckey,
+                              ctx->key->mlkem_ctx);
+
+    debug_print("MLKEMKEM decaps returns %d\n", ret);
+    return ret;
+}
+
+const OSSL_DISPATCH ossl_mlkem768_asym_kem_functions[] = {
+    { OSSL_FUNC_KEM_NEWCTX, (void (*)(void))mlkem_newctx },
+    { OSSL_FUNC_KEM_ENCAPSULATE_INIT,
+      (void (*)(void))mlkem_encapsulate_init },
+    { OSSL_FUNC_KEM_ENCAPSULATE, (void (*)(void))mlkem_encapsulate },
+    { OSSL_FUNC_KEM_DECAPSULATE_INIT,
+      (void (*)(void))mlkem_decapsulate_init },
+    { OSSL_FUNC_KEM_DECAPSULATE, (void (*)(void))mlkem_decapsulate },
+    { OSSL_FUNC_KEM_FREECTX, (void (*)(void))mlkem_freectx },
+    { OSSL_FUNC_KEM_SET_CTX_PARAMS,
+      (void (*)(void))mlkem_set_ctx_params },
+    { OSSL_FUNC_KEM_SETTABLE_CTX_PARAMS,
+      (void (*)(void))mlkem_settable_ctx_params },
+    OSSL_DISPATCH_END
+};
index edfb15f7e2d29d37131b02f422c3fbdb93beb879..b1ee39c4fe008396811b4dc9ddb9e09f04cb7bc8 100644 (file)
@@ -10,6 +10,7 @@ $MAC_GOAL=../../libdefault.a ../../libfips.a
 $RSA_GOAL=../../libdefault.a ../../libfips.a
 $TEMPLATE_GOAL=../../libtemplate.a
 $ML_DSA_GOAL=../../libdefault.a ../../libfips.a
+$MLKEM_GOAL=../../libdefault.a
 
 IF[{- !$disabled{dh} -}]
   SOURCE[$DH_GOAL]=dh_kmgmt.c
@@ -49,3 +50,4 @@ SOURCE[$TEMPLATE_GOAL]=template_kmgmt.c
 IF[{- !$disabled{'ml-dsa'} -}]
   SOURCE[$ML_DSA_GOAL]=ml_dsa_kmgmt.c
 ENDIF
+SOURCE[$MLKEM_GOAL]=mlkem_kmgmt.c
diff --git a/providers/implementations/keymgmt/mlkem_kmgmt.c b/providers/implementations/keymgmt/mlkem_kmgmt.c
new file mode 100644 (file)
index 0000000..02d7e5f
--- /dev/null
@@ -0,0 +1,590 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+#include <openssl/core_dispatch.h>
+#include <openssl/core_names.h>
+#include <openssl/params.h>
+#include <openssl/err.h>
+#include <openssl/proverr.h>
+#include <openssl/rand.h>
+#include <openssl/self_test.h>
+#include "internal/param_build_set.h"
+#include <openssl/param_build.h>
+#include "prov/mlkem.h"
+#include "prov/implementations.h"
+#include "prov/providercommon.h"
+#include "prov/provider_ctx.h"
+#include "prov/securitycheck.h"
+#include <assert.h>
+
+#define BUFSIZE 1000
+#if defined(NDEBUG) || defined(OPENSSL_NO_STDIO)
+/* TODO(ML-KEM) to remove or replace with TRACE */
+static void debug_print(char *fmt, ...)
+{
+}
+#else
+static void debug_print(char *fmt, ...)
+{
+    char out[BUFSIZE];
+    va_list argptr;
+
+    va_start(argptr, fmt);
+    vsnprintf(out, BUFSIZE, fmt, argptr);
+    va_end(argptr);
+    if (getenv("TEMPLATEKM"))
+        fprintf(stderr, "TEMPLATE_KM: %s", out);
+}
+#endif
+
+static void print_hex(const uint8_t *data, int len, const char *msg)
+{
+#ifndef NDEBUG
+    if (msg)
+        printf("%s: \n", msg);
+    BIO_dump_fp(stdout, data, len);
+    printf("\n\n");
+#endif
+}
+
+static OSSL_FUNC_keymgmt_new_fn mlkem_new;
+static OSSL_FUNC_keymgmt_free_fn mlkem_free;
+static OSSL_FUNC_keymgmt_gen_init_fn mlkem_gen_init;
+static OSSL_FUNC_keymgmt_gen_fn mlkem_gen;
+static OSSL_FUNC_keymgmt_gen_cleanup_fn mlkem_gen_cleanup;
+static OSSL_FUNC_keymgmt_gen_set_params_fn mlkem_gen_set_params;
+static OSSL_FUNC_keymgmt_gen_settable_params_fn mlkem_gen_settable_params;
+static OSSL_FUNC_keymgmt_get_params_fn mlkem_get_params;
+static OSSL_FUNC_keymgmt_gettable_params_fn mlkem_gettable_params;
+static OSSL_FUNC_keymgmt_set_params_fn mlkem_set_params;
+static OSSL_FUNC_keymgmt_settable_params_fn mlkem_settable_params;
+static OSSL_FUNC_keymgmt_has_fn mlkem_has;
+static OSSL_FUNC_keymgmt_match_fn mlkem_match;
+
+/* implement only when encode/decode logic becomes required/standardized: */
+#ifdef UNDEF
+static OSSL_FUNC_keymgmt_import_fn mlkem_import;
+static OSSL_FUNC_keymgmt_export_fn mlkem_export;
+static OSSL_FUNC_keymgmt_import_types_fn mlkem_imexport_types;
+static OSSL_FUNC_keymgmt_export_types_fn mlkem_imexport_types;
+static OSSL_FUNC_keymgmt_dup_fn mlkem_dup;
+#endif /* UNDEF */
+
+struct mlkem_gen_ctx {
+    void *provctx;
+    int selection;
+};
+
+static void *mlkem_new(void *provctx)
+{
+    MLKEM768_KEY *key = NULL;
+
+    debug_print("MLKEMKM new key req\n");
+    if (!ossl_prov_is_running())
+        return 0;
+
+    key = OPENSSL_zalloc(sizeof(MLKEM768_KEY));
+    if (key != NULL) {
+        key->keytype = MLKEM_KEY_TYPE_768; /* TODO(ML-KEM) any type */
+        key->provctx = provctx;
+        /*
+         * ideally, this is a one-time allocation and ctx that should be within the
+         * provider context: OK to move it there to improve performance?? It would be
+         * the first algorithmspecific context stored: Feels weird (TODO(ML-KEM)).
+         */
+        key->mlkem_ctx = ossl_mlkem_newctx(provctx == NULL ? NULL : PROV_LIBCTX_OF(provctx), NULL);
+        if (key->mlkem_ctx == NULL) {
+            OPENSSL_free(key);
+            key = NULL;
+        }
+    }
+
+    debug_print("MLKEMKM new key = %p\n", key);
+    return key;
+}
+
+static void mlkem_free(void *vkey)
+{
+    MLKEM768_KEY *mkey = (MLKEM768_KEY *)vkey;
+
+    debug_print("MLKEMKM free key %p\n", mkey);
+    if (mkey == NULL)
+        return;
+    ossl_mlkem_ctx_free(mkey->mlkem_ctx);
+    OPENSSL_free(mkey->encoded_pubkey);
+    OPENSSL_free(mkey);
+}
+
+static int mlkem_has(const void *keydata, int selection)
+{
+    const MLKEM768_KEY *key = keydata;
+    int ok = 0;
+
+    debug_print("MLKEMKM has %p\n", key);
+    if (ossl_prov_is_running() && key != NULL) {
+        /*
+         * ML-KEM keys always have all the parameters they need (i.e. none).
+         * Therefore we always return with 1, if asked about parameters.
+         */
+        ok = 1;
+
+        if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0)
+            ok = ok && key->pubkey_initialized == 1;
+
+        if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0)
+            ok = ok && key->seckey_initialized == 1;
+    }
+    debug_print("MLKEMKM has result %d\n", ok);
+    return ok;
+}
+
+static int mlkem_match(const void *keydata1, const void *keydata2, int selection)
+{
+    const MLKEM768_KEY *key1 = keydata1;
+    const MLKEM768_KEY *key2 = keydata2;
+    int ok = 1;
+
+    debug_print("MLKEMKM matching %p and %p\n", key1, key2);
+    if (!ossl_prov_is_running())
+        return 0;
+
+    if ((selection & OSSL_KEYMGMT_SELECT_DOMAIN_PARAMETERS) != 0)
+        ok = ok && key1->keytype == key2->keytype;
+
+    /* TODO(ML-KEM) */
+    debug_print("MLKEMKM matching for now NOT YET IMPLEMENTED\n");
+
+/* TODO(ML-KEM) template code to be completed as and when needed: */
+#ifdef UNDEF
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) {
+        int key_checked = 0;
+
+        if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0) {
+            const uint8_t *pa = key1->pubkey;
+            const uint8_t *pb = key2->pubkey;
+
+            if (pa != NULL && pb != NULL) {
+                ok = ok
+                    && key1->keytype == key2->keytype
+                    && CRYPTO_memcmp(pa, pb, MLKEM768_PUBLICKEYBYTES) == 0;
+                key_checked = 1;
+            }
+        }
+        if (!key_checked
+            && (selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) {
+            const uint8_t *pa = key1->seckey;
+            const uint8_t *pb = key2->seckey;
+
+            if (pa != NULL && pb != NULL) {
+                ok = ok
+                    && key1->keytype == key2->keytype
+                    && CRYPTO_memcmp(pa, pb, MLKEM768_SECRETKEYBYTES) == 0;
+                key_checked = 1;
+            }
+        }
+        ok = ok && key_checked;
+    }
+#endif /* UNDEF */
+    debug_print("MLKEMKM match result %d\n", ok);
+    return ok;
+}
+
+/* TODO(ML-KEM) as and when encode/decode becomes needed/standardized */
+#ifdef UNDEF
+static int key_to_params(MLKEM768_KEY *key, OSSL_PARAM_BLD *tmpl,
+                         OSSL_PARAM params[], int include_private)
+{
+    if (key == NULL)
+        return 0;
+
+    if (key->keytype != MLKEM_KEY_TYPE_768)
+        return 0;
+
+    if (!ossl_param_build_set_octet_string(tmpl, params,
+                                           OSSL_PKEY_PARAM_PUB_KEY,
+                                           key->pubkey, MLKEM768_PUBLICKEYBYTES))
+        return 0;
+
+    if (include_private
+        && key->seckey != NULL
+        && !ossl_param_build_set_octet_string(tmpl, params,
+                                              OSSL_PKEY_PARAM_PRIV_KEY,
+                                              key->seckey, MLKEM768_SECRETKEYBYTES))
+        return 0;
+
+    return 1;
+}
+
+static int mlkem_export(void *key, int selection, OSSL_CALLBACK *param_cb,
+                        void *cbarg)
+{
+    MLKEM768_KEY *mkey = key;
+    OSSL_PARAM_BLD *tmpl;
+    OSSL_PARAM *params = NULL;
+    int ret = 0;
+
+    debug_print("MLKEMKM export %p\n", key);
+    if (!ossl_prov_is_running() || key == NULL)
+        return 0;
+
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0)
+        return 0;
+
+    tmpl = OSSL_PARAM_BLD_new();
+    if (tmpl == NULL)
+        return 0;
+
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) {
+        int include_private = ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0);
+
+        if (!key_to_params(mkey, tmpl, NULL, include_private))
+            goto err;
+    }
+
+    params = OSSL_PARAM_BLD_to_param(tmpl);
+    if (params == NULL)
+        goto err;
+
+    ret = param_cb(params, cbarg);
+    OSSL_PARAM_free(params);
+err:
+    OSSL_PARAM_BLD_free(tmpl);
+    debug_print("MLKEMKM export result %d\n", ret);
+    return ret;
+}
+
+static int ossl_mlkem_key_fromdata(MLKEM768_KEY *key,
+                                   const OSSL_PARAM params[],
+                                   int include_private)
+{
+    size_t privkeylen = 0, pubkeylen = 0;
+    const OSSL_PARAM *param_priv_key = NULL, *param_pub_key;
+    unsigned char *pubkey;
+
+    if (key == NULL)
+        return 0;
+
+    if (key->keytype != MLKEM_KEY_TYPE_768)
+        return 0;
+
+    param_pub_key = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY);
+    if (include_private)
+        param_priv_key =
+            OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PRIV_KEY);
+
+    if (param_pub_key == NULL && param_priv_key == NULL)
+        return 0;
+
+    if (param_priv_key != NULL) {
+        if (!OSSL_PARAM_get_octet_string(param_priv_key,
+                                         (void **)&key->seckey,
+                                         MLKEM768_SECRETKEYBYTES,
+                                         &privkeylen))
+            return 0;
+        if (privkeylen != MLKEM768_SECRETKEYBYTES) {
+            debug_print("sec key len mismatch in import: %ld vs %d: HOWCAN?\n",
+                        privkeylen, MLKEM768_SECRETKEYBYTES);
+            OPENSSL_secure_clear_free(key->seckey, privkeylen);
+            key->seckey = NULL;
+            return 0;
+        }
+    }
+
+    pubkey = key->pubkey;
+    if (param_pub_key != NULL
+        && !OSSL_PARAM_get_octet_string(param_pub_key,
+                                        (void **)&pubkey,
+                                        ossl_mlkem768_PUBLIC_KEY_BYTES,
+                                        &pubkeylen))
+        return 0;
+
+    if ((param_pub_key != NULL && pubkeylen != ossl_mlkem768_PUBLIC_KEY_BYTES)) {
+        debug_print("sec key len mismatch in import: %ld vs %d: HOWCAN?\n",
+                    pubkeylen, ossl_mlkem768_PUBLIC_KEY_BYTES);
+        return 0;
+    }
+
+    /*
+     * TBD if hybrid logic is not getting cleanly implemented in separate logic:
+     * reconstitute (only) classic part here
+     */
+
+    return 1;
+}
+
+static int mlkem_import(void *key, int selection, const OSSL_PARAM params[])
+{
+    MLKEM768_KEY *mkey = key;
+    int ok = 1;
+    int include_private;
+
+    debug_print("MLKEMKM import %p\n", mkey);
+    if (!ossl_prov_is_running() || key == NULL)
+        return 0;
+
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0)
+        return 0;
+
+    include_private = selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY ? 1 : 0;
+    ok = ok && ossl_mlkem_key_fromdata(mkey, params, include_private);
+
+    debug_print("MLKEMKM import result %d\n", ok);
+    return ok;
+}
+
+# define MLKEM768_KEY_TYPES()                                           \
+    OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PUB_KEY, NULL, 0),          \
+    OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PRIV_KEY, NULL, 0)
+
+static const OSSL_PARAM mlkem_key_types[] = {
+    MLKEM768_KEY_TYPES(),
+    OSSL_PARAM_END
+};
+
+static const OSSL_PARAM *mlkem_imexport_types(int selection)
+{
+    debug_print("MLKEMKM getting imexport types\n");
+    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
+        return mlkem_key_types;
+    return NULL;
+}
+
+#endif /* UNDEF */
+
+static int mlkem_get_params(void *key, OSSL_PARAM params[])
+{
+    MLKEM768_KEY *mkey = key;
+    OSSL_PARAM *p;
+
+    debug_print("MLKEMKM get params %p\n", mkey);
+    if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_BITS)) != NULL
+        && !OSSL_PARAM_set_int(p, sizeof(ossl_mlkem768_private_key) * 8))
+        return 0;
+    if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_SECURITY_BITS)) != NULL
+        && !OSSL_PARAM_set_int(p, OSSL_MLKEM768_SECURITY_BITS))
+        return 0;
+    if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_MAX_SIZE)) != NULL
+        && !OSSL_PARAM_set_int(p, OSSL_MLKEM768_CIPHERTEXT_BYTES))
+        return 0;
+    if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY)) != NULL
+        && mkey->pubkey_initialized == 1) {
+        if (!OSSL_PARAM_set_octet_string(p, mkey->encoded_pubkey, OSSL_MLKEM768_PUBLIC_KEY_BYTES))
+            return 0;
+        debug_print("MLKEMKM got encoded public key of len %d\n", OSSL_MLKEM768_PUBLIC_KEY_BYTES);
+        print_hex(mkey->encoded_pubkey, OSSL_MLKEM768_PUBLIC_KEY_BYTES, "enc PK");
+    }
+
+    debug_print("MLKEMKM get params OK\n");
+    return 1;
+}
+
+static const OSSL_PARAM mlkem_gettable_params_arr[] = {
+    OSSL_PARAM_int(OSSL_PKEY_PARAM_BITS, NULL),
+    OSSL_PARAM_int(OSSL_PKEY_PARAM_SECURITY_BITS, NULL),
+    OSSL_PARAM_int(OSSL_PKEY_PARAM_MAX_SIZE, NULL),
+    OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0),
+    OSSL_PARAM_END
+};
+
+static const OSSL_PARAM *mlkem_gettable_params(void *provctx)
+{
+    debug_print("MLKEMKM gettable params called\n");
+    return mlkem_gettable_params_arr;
+}
+
+static int mlkem_set_params(void *key, const OSSL_PARAM params[])
+{
+    MLKEM768_KEY *mkey = key;
+    const OSSL_PARAM *p;
+
+    debug_print("MLKEMKM set params called for %p\n", mkey);
+    if (params == NULL)
+        return 1;
+
+    p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY);
+    if (p != NULL) {
+        size_t len_stored;
+
+        if (p->data_size != OSSL_MLKEM768_PUBLIC_KEY_BYTES
+                || !OSSL_PARAM_get_octet_string(p, (void **)&mkey->encoded_pubkey,
+                                                OSSL_MLKEM768_PUBLIC_KEY_BYTES,
+                                                &len_stored))
+            return 0;
+        debug_print("encoded pub key successfully stored with %ld bytes\n", len_stored);
+        if (!ossl_mlkem768_recreate_public_key(mkey->encoded_pubkey, &mkey->pubkey,
+                                               mkey->mlkem_ctx))
+            return 0;
+        mkey->pubkey_initialized = 1;
+    }
+
+    debug_print("MLKEMKM set params OK\n");
+    return 1;
+}
+
+static const OSSL_PARAM mlkem_settable_params_arr[] = {
+    OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0),
+    OSSL_PARAM_END
+};
+
+static const OSSL_PARAM *mlkem_settable_params(void *provctx)
+{
+    debug_print("MLKEMKM settable params called\n");
+    return mlkem_settable_params_arr;
+}
+
+static void *mlkem_gen_init(void *provctx, int selection,
+                            const OSSL_PARAM params[])
+{
+    struct mlkem_gen_ctx *gctx = NULL;
+
+    debug_print("MLKEMKM gen init called for %p\n", provctx);
+    if (!ossl_prov_is_running())
+        return NULL;
+
+    if ((gctx = OPENSSL_zalloc(sizeof(*gctx))) != NULL) {
+        gctx->provctx = provctx;
+        gctx->selection = selection;
+    }
+    if (!mlkem_gen_set_params(gctx, params)) {
+        OPENSSL_free(gctx);
+        gctx = NULL;
+    }
+    debug_print("MLKEMKM gen init returns %p\n", gctx);
+    return gctx;
+}
+
+static int mlkem_gen_set_params(void *genctx, const OSSL_PARAM params[])
+{
+    struct mlkem_gen_ctx *gctx = genctx;
+
+    if (gctx == NULL)
+        return 0;
+
+    debug_print("MLKEMKM empty gen_set params called for %p\n", gctx);
+    return 1;
+}
+
+static const OSSL_PARAM *mlkem_gen_settable_params(ossl_unused void *genctx,
+                                                   ossl_unused void *provctx)
+{
+    static OSSL_PARAM settable[] = {
+        OSSL_PARAM_END
+    };
+    return settable;
+}
+
+static void *mlkem_gen(void *vctx, OSSL_CALLBACK *osslcb, void *cbarg)
+{
+    struct mlkem_gen_ctx *gctx = (struct mlkem_gen_ctx *)vctx;
+    MLKEM768_KEY *mkey;
+
+    debug_print("MLKEMKM gen called for %p\n", gctx);
+    if (gctx == NULL)
+        return NULL;
+
+    if ((mkey = mlkem_new(gctx->provctx)) == NULL) {
+        ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
+        return NULL;
+    }
+
+    /* If we're doing parameter generation then we just return a blank key */
+    if ((gctx->selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0) {
+        debug_print("MLKEMKM gen returns blank %p\n", mkey);
+        return mkey;
+    }
+
+    mkey->keytype = MLKEM_KEY_TYPE_768;
+
+    if (mkey->encoded_pubkey == NULL) {
+        mkey->encoded_pubkey = OPENSSL_malloc(OSSL_MLKEM768_PUBLIC_KEY_BYTES);
+        if (mkey->encoded_pubkey == NULL)
+            goto err;
+    }
+
+    if (!ossl_mlkem768_generate_key(mkey->encoded_pubkey, NULL, &mkey->seckey,
+                                    mkey->mlkem_ctx)
+        || !ossl_mlkem768_public_from_private(&mkey->pubkey, &mkey->seckey))
+        goto err;
+
+    mkey->seckey_initialized = 1;
+    mkey->pubkey_initialized = 1;
+
+    debug_print("MLKEMKM gen returns set %p\n", mkey);
+    return mkey;
+
+err:
+    OPENSSL_free(mkey);
+    return NULL;
+}
+
+static void mlkem_gen_cleanup(void *genctx)
+{
+    struct mlkem_gen_ctx *gctx = genctx;
+
+    debug_print("MLKEMKM gen cleanup for %p\n", gctx);
+    OPENSSL_free(gctx);
+}
+
+static void *mlkem_dup(const void *vsrckey, int selection)
+{
+    const MLKEM768_KEY *srckey = (const MLKEM768_KEY *)vsrckey;
+    MLKEM768_KEY *dstkey;
+
+    debug_print("MLKEMKM dup called for %p\n", srckey);
+    if (!ossl_prov_is_running())
+        return NULL;
+
+    dstkey = mlkem_new(srckey->provctx);
+    if (dstkey == NULL)
+        return NULL;
+
+    dstkey->keytype = srckey->keytype;
+    if (srckey->pubkey_initialized == 1
+            && (selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0) {
+        memcpy((void *)&dstkey->pubkey, (void *)&srckey->pubkey, sizeof(srckey->pubkey));
+        dstkey->encoded_pubkey = OPENSSL_malloc(OSSL_MLKEM768_PUBLIC_KEY_BYTES);
+        if (srckey->encoded_pubkey != NULL)
+            memcpy(dstkey->encoded_pubkey, srckey->encoded_pubkey, OSSL_MLKEM768_PUBLIC_KEY_BYTES);
+        dstkey->pubkey_initialized = 1;
+    }
+    if (srckey->seckey_initialized == 1
+            && (selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) {
+        memcpy((void *)&dstkey->seckey, (void *)&srckey->seckey, sizeof(srckey->seckey));
+        dstkey->seckey_initialized = 1;
+    }
+
+    debug_print("MLKEMKM dup returns %p\n", dstkey);
+    return dstkey;
+}
+
+const OSSL_DISPATCH ossl_mlkem768_keymgmt_functions[] = {
+    { OSSL_FUNC_KEYMGMT_NEW, (void (*)(void))mlkem_new },
+    { OSSL_FUNC_KEYMGMT_FREE, (void (*)(void))mlkem_free },
+    { OSSL_FUNC_KEYMGMT_GET_PARAMS, (void (*) (void))mlkem_get_params },
+    { OSSL_FUNC_KEYMGMT_GETTABLE_PARAMS, (void (*) (void))mlkem_gettable_params },
+    { OSSL_FUNC_KEYMGMT_SET_PARAMS, (void (*) (void))mlkem_set_params },
+    { OSSL_FUNC_KEYMGMT_SETTABLE_PARAMS, (void (*) (void))mlkem_settable_params },
+    { OSSL_FUNC_KEYMGMT_HAS, (void (*)(void))mlkem_has },
+    { OSSL_FUNC_KEYMGMT_MATCH, (void (*)(void))mlkem_match },
+    { OSSL_FUNC_KEYMGMT_GEN_INIT, (void (*)(void))mlkem_gen_init },
+    { OSSL_FUNC_KEYMGMT_GEN_SET_PARAMS, (void (*)(void))mlkem_gen_set_params },
+    { OSSL_FUNC_KEYMGMT_GEN_SETTABLE_PARAMS,
+      (void (*)(void))mlkem_gen_settable_params },
+    { OSSL_FUNC_KEYMGMT_GEN, (void (*)(void))mlkem_gen },
+    { OSSL_FUNC_KEYMGMT_GEN_CLEANUP, (void (*)(void))mlkem_gen_cleanup },
+    { OSSL_FUNC_KEYMGMT_DUP, (void (*)(void))mlkem_dup },
+    /*
+     * TODO(ML-KEM) don't do for now, see https://github.com/openssl/private/issues/698
+     *  { OSSL_FUNC_KEYMGMT_IMPORT_TYPES, (void (*)(void))mlkem_imexport_types },
+     *  { OSSL_FUNC_KEYMGMT_EXPORT_TYPES, (void (*)(void))mlkem_imexport_types },
+     *  { OSSL_FUNC_KEYMGMT_IMPORT, (void (*)(void))mlkem_export },
+     *  { OSSL_FUNC_KEYMGMT_EXPORT, (void (*)(void))mlkem_import },
+     */
+    OSSL_DISPATCH_END
+};
index 5e511590d342b48235e229c6b8762bd8a9e973de..ad3831d5ed20f8b228ec0b5be1b08fd3ab30de37 100644 (file)
@@ -1011,6 +1011,13 @@ IF[{- !$disabled{tests} -}]
     INCLUDE[asn1_dsa_internal_test]=.. ../include ../apps/include
     DEPEND[asn1_dsa_internal_test]=../libcrypto.a libtestutil.a
 
+    IF[{- !$disabled{'mlkem'} -}]
+      PROGRAMS{noinst}=mlkem_internal_test
+      SOURCE[mlkem_internal_test]=mlkem_internal_test.c
+      INCLUDE[mlkem_internal_test]=../include ../apps/include
+      DEPEND[mlkem_internal_test]=../libcrypto.a libtestutil.a
+    ENDIF
+
     SOURCE[keymgmt_internal_test]=keymgmt_internal_test.c
     INCLUDE[keymgmt_internal_test]=.. ../include ../apps/include
     DEPEND[keymgmt_internal_test]=../libcrypto.a libtestutil.a
index b9124d02b56d88dc8415bc0ac6575925a69860a9..a50ba9b43822a7d371aa94055cc73860a3a50bd1 100644 (file)
@@ -5912,6 +5912,109 @@ static int test_invalid_ctx_for_digest(void)
     return ret;
 }
 
+static int test_ml_kem(void)
+{
+    EVP_PKEY *akey, *bkey = NULL;
+    int res = 0;
+    size_t publen;
+    unsigned char *rawpub = NULL;
+    EVP_PKEY_CTX *ctx = NULL;
+    unsigned char *wrpkey = NULL, *agenkey = NULL, *bgenkey = NULL;
+    size_t wrpkeylen, agenkeylen, bgenkeylen, i;
+
+    /* Generate Alice's key */
+    akey = EVP_PKEY_Q_keygen(testctx, NULL, "ML-KEM-768");
+    if (!TEST_ptr(akey))
+        goto err;
+
+    /* Get the raw public key */
+    publen = EVP_PKEY_get1_encoded_public_key(akey, &rawpub);
+    if (!TEST_size_t_gt(publen, 0))
+        goto err;
+
+    /* Create Bob's key and populate it with Alice's public key data */
+    bkey = EVP_PKEY_new();
+    if (!TEST_ptr(bkey))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_copy_parameters(bkey, akey), 0))
+        goto err;
+
+    if (!TEST_true(EVP_PKEY_set1_encoded_public_key(bkey, rawpub, publen)))
+        goto err;
+
+    /* Encapsulate Bob's key */
+    ctx = EVP_PKEY_CTX_new_from_pkey(testctx, bkey, NULL);
+    if (!TEST_ptr(ctx))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_encapsulate_init(ctx, NULL), 0))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_encapsulate(ctx, NULL, &wrpkeylen, NULL,
+                                          &bgenkeylen), 0))
+        goto err;
+
+    if (!TEST_size_t_gt(wrpkeylen, 0) || !TEST_size_t_gt(bgenkeylen, 0))
+        goto err;
+
+    wrpkey = OPENSSL_zalloc(wrpkeylen);
+    bgenkey = OPENSSL_zalloc(bgenkeylen);
+    if (!TEST_ptr(wrpkey) || !TEST_ptr(bgenkey))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_encapsulate(ctx, wrpkey, &wrpkeylen, bgenkey,
+                                          &bgenkeylen), 0))
+        goto err;
+
+    EVP_PKEY_CTX_free(ctx);
+
+    /* Alice now decapsulates Bob's key */
+    ctx = EVP_PKEY_CTX_new_from_pkey(testctx, akey, NULL);
+    if (!TEST_ptr(ctx))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_decapsulate_init(ctx, NULL), 0))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_decapsulate(ctx, NULL, &agenkeylen, wrpkey,
+                                          wrpkeylen), 0))
+        goto err;
+
+    if (!TEST_size_t_gt(agenkeylen, 0))
+        goto err;
+
+    agenkey = OPENSSL_zalloc(agenkeylen);
+    if (!TEST_ptr(agenkey))
+        goto err;
+
+    if (!TEST_int_gt(EVP_PKEY_decapsulate(ctx, agenkey, &agenkeylen, wrpkey,
+                                          wrpkeylen), 0))
+        goto err;
+
+    /* Hopefully we ended up with a shared key */
+    if (!TEST_mem_eq(agenkey, agenkeylen, bgenkey, bgenkeylen))
+        goto err;
+
+    /* Verify we generated a non-zero shared key */
+    for (i = 0; i < agenkeylen; i++)
+        if (agenkey[i] != 0)
+            break;
+    if (!TEST_size_t_ne(i, agenkeylen))
+        return 0;
+
+    res = 1;
+ err:
+    EVP_PKEY_CTX_free(ctx);
+    EVP_PKEY_free(akey);
+    EVP_PKEY_free(bkey);
+    OPENSSL_free(rawpub);
+    OPENSSL_free(wrpkey);
+    OPENSSL_free(agenkey);
+    OPENSSL_free(bgenkey);
+    return res;
+}
+
 static int test_evp_cipher_pipeline(void)
 {
     OSSL_PROVIDER *fake_pipeline = NULL;
@@ -6310,6 +6413,7 @@ int setup_tests(void)
 #endif
 
     ADD_TEST(test_invalid_ctx_for_digest);
+    ADD_TEST(test_ml_kem);
 
     ADD_TEST(test_evp_cipher_pipeline);
 
diff --git a/test/mlkem_internal_test.c b/test/mlkem_internal_test.c
new file mode 100644 (file)
index 0000000..9dbd795
--- /dev/null
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+#include <openssl/opensslconf.h>
+#ifndef OPENSSL_NO_STDIO
+# include <stdio.h>
+#endif
+
+#include <crypto/mlkem.h>
+
+#include <string.h>
+#include "testutil.h"
+#include "testutil/output.h"
+
+int main(void)
+{
+    uint8_t out_encoded_public_key[OSSL_MLKEM768_PUBLIC_KEY_BYTES];
+    uint8_t out_ciphertext[OSSL_MLKEM768_CIPHERTEXT_BYTES];
+    uint8_t out_shared_secret[OSSL_MLKEM768_SHARED_SECRET_BYTES];
+    uint8_t out_shared_secret2[OSSL_MLKEM768_SHARED_SECRET_BYTES];
+    ossl_mlkem768_private_key private_key;
+    ossl_mlkem768_public_key public_key;
+    ossl_mlkem768_public_key recreated_public_key;
+    uint8_t *p1, *p2;
+    ossl_mlkem_ctx *mlkem_ctx = ossl_mlkem_newctx(NULL, NULL);
+    int ret = 1;
+
+    /* enable TEST_* API */
+    test_open_streams();
+
+    /* first, generate a key pair */
+    if (!ossl_mlkem768_generate_key(out_encoded_public_key, NULL,
+                                    &private_key, mlkem_ctx)) {
+        ret = -1;
+        goto end;
+    }
+    /* public key component to be created from private key */
+    if (!ossl_mlkem768_public_from_private(&public_key, &private_key)) {
+        ret = -2;
+        goto end;
+    }
+    /* try to re-create public key structure from encoded public key */
+    if (!ossl_mlkem768_recreate_public_key(out_encoded_public_key,
+                                           &recreated_public_key, mlkem_ctx)) {
+        ret = -3;
+        goto end;
+    }
+    /* validate identity of both public key structures */
+    p1 = (uint8_t *)&public_key;
+    p2 = (uint8_t *)&recreated_public_key;
+    if (!TEST_int_eq(memcmp(p1, p2, sizeof(public_key)), 0)) {
+        ret = -4;
+        goto end;
+    }
+    /* encaps - decaps test: validate shared secret identity */
+    if (!ossl_mlkem768_encap(out_ciphertext, out_shared_secret,
+                             &recreated_public_key, mlkem_ctx)) {
+        ret = -5;
+        goto end;
+    }
+    if (!ossl_mlkem768_decap(out_shared_secret2, out_ciphertext,
+                             OSSL_MLKEM768_CIPHERTEXT_BYTES, &private_key, mlkem_ctx)) {
+        ret = -6;
+        goto end;
+    }
+    if (!TEST_int_eq(memcmp(out_shared_secret, out_shared_secret2,
+                            OSSL_MLKEM768_SHARED_SECRET_BYTES), 0)) {
+        ret = -7;
+        goto end;
+    }
+    /* so far so good, now a quick negative test by breaking the ciphertext */
+    out_ciphertext[0]++;
+    if (!ossl_mlkem768_decap(out_shared_secret2, out_ciphertext,
+                             OSSL_MLKEM768_CIPHERTEXT_BYTES, &private_key, mlkem_ctx))
+        goto end;
+    /* If decap passed, ensure we at least have a mismatch */
+    if (!TEST_int_ne(memcmp(out_shared_secret, out_shared_secret2,
+                            OSSL_MLKEM768_SHARED_SECRET_BYTES), 0))
+        ret = -8;
+
+end:
+    ossl_mlkem_ctx_free(mlkem_ctx);
+    return ret;
+}