From: Michael Baentsch <57787676+baentsch@users.noreply.github.com> Date: Mon, 11 Nov 2024 08:08:06 +0000 (+0100) Subject: Add ML-KEM-768 implementation X-Git-Tag: openssl-3.5.0-alpha1~553 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=96a079a03ff1239abbfd877b8dab91ba657fc4d1;p=thirdparty%2Fopenssl.git Add ML-KEM-768 implementation Based on code from BoringSSL covered under Google CCLA Original code at https://boringssl.googlesource.com/boringssl/+/HEAD/crypto/mlkem - VSCode automatic formatting (andrewd@openssl.org) - Just do some basic formatting to make diffs easier to read later: convert from 2 to 4 spaces, add newlines after function declarations, and move function open curly brace to new line (andrewd@openssl.org) - Move variable init to beginning of each function (andrewd@openssl.org) - Replace CBB API - Fixing up constants and parameter lists - Replace BORINGSSL_keccak calls with EVP calls - Added library symbols and low-level test case - Switch boringssl constant time routines for OpenSSL ones - Data type assertion and negative test added - Moved mlkem.h to include/crypto - Changed function naming to be in line with ossl convention - Remove Google license terms based on CCLA - Add constant_time_lt_32 - Convert asserts to ossl_asserts where possible - Add bssl keccak, pubK recreation, formatting - Add provider interface to utilize mlkem768 code enabling TLS1.3 use - Revert to OpenSSL DigestXOF - Use EVP_MD_xof() to determine digest finalisation (pauli@openssl.org) - Change APIs to return error codes; reference new IANA number; move static asserts to one place - Remove boringssl keccak for good - Fix coding style and return value checks - ANSI C compatibility changes - Remove static cache objects - All internal retval functions used leading to some new retval functions Reviewed-by: Tomas Mraz Reviewed-by: Matt Caswell (Merged from https://github.com/openssl/openssl/pull/25848) --- diff --git a/Configure b/Configure index 98ad2dc8248..92fd97fd2ea 100755 --- a/Configure +++ b/Configure @@ -487,6 +487,7 @@ my @disablables = ( "md4", "mdc2", "ml-dsa", + "mlkem", "module", "msan", "multiblock", diff --git a/crypto/build.info b/crypto/build.info index e476b678da3..72d5305616b 100644 --- a/crypto/build.info +++ b/crypto/build.info @@ -2,7 +2,7 @@ # there for further explanations. SUBDIRS=objects buffer bio stack lhash hashtable rand evp asn1 pem x509 conf \ txt_db pkcs7 pkcs12 ui kdf store property \ - md2 md4 md5 sha mdc2 hmac ripemd whrlpool poly1305 \ + md2 md4 md5 sha mdc2 mlkem hmac ripemd whrlpool poly1305 \ siphash sm3 des aes rc2 rc4 rc5 idea aria bf cast camellia \ seed sm4 chacha modes bn ec rsa dsa dh sm2 dso engine \ err comp http ocsp cms ts srp cmac ct async ess crmf cmp encode_decode \ diff --git a/crypto/mlkem/build.info b/crypto/mlkem/build.info new file mode 100644 index 00000000000..5d4ebe31b4a --- /dev/null +++ b/crypto/mlkem/build.info @@ -0,0 +1,3 @@ +LIBS = ../../libcrypto + +SOURCE[../../libcrypto] = mlkem768.c diff --git a/crypto/mlkem/mlkem768.c b/crypto/mlkem/mlkem768.c new file mode 100644 index 00000000000..8e038f4a4cb --- /dev/null +++ b/crypto/mlkem/mlkem768.c @@ -0,0 +1,1197 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +/* Copyright (c) 2024, Google Inc. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef NDEBUG +# include +#endif + +#ifndef OPENSSL_NO_MLKEM + +/* Constants that are common across all sizes. */ +# define DEGREE 256 +static const size_t kBarrettMultiplier = 5039; +static const unsigned kBarrettShift = 24; +static const uint16_t kPrime = 3329; +static const int kLog2Prime = 12; +static const uint16_t kHalfPrime = (/* kPrime= */ 3329 - 1) / 2; + +/* + * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th + * root of unity. + */ +static const uint16_t kInverseDegree = 3303; + +/* Rank-specific constants. */ +# define RANK768 3 +static const int kDU768 = 10; +static const int kDV768 = 4; +# define RANK1024 4 +static const int kDU1024 = 11; +static const int kDV1024 = 5; + +static ossl_inline size_t compressed_vector_size(int rank) +{ + return (rank == RANK768 ? kDU768 : kDU1024) * (size_t)rank * + DEGREE / 8; +} + +static ossl_inline size_t ciphertext_size(int rank) +{ + return compressed_vector_size(rank) + + (rank == RANK768 ? kDV768 : kDV1024) * DEGREE / 8; +} + +typedef struct scalar { + /* On every function entry and exit, 0 <= c < kPrime. */ + uint16_t c[DEGREE]; +} scalar; + +/* TODO(ML-KEM): possibly rename vector768 to allow for other algs */ +typedef struct vector { + scalar v[RANK768]; +} vector; + +/* TODO(ML-KEM): possibly rename matrix768 to allow for other algs */ +typedef struct matrix { + scalar v[RANK768][RANK768]; +} matrix; + +typedef struct public_key_RANK768 { + vector t; + uint8_t rho[32]; + uint8_t public_key_hash[32]; + matrix m; +} public_key_RANK768; + +typedef struct private_key_RANK768 { + struct public_key_RANK768 pub; + vector s; + uint8_t fo_failure_secret[32]; +} private_key_RANK768; + +static ossl_inline size_t encoded_vector_size(int rank) +{ + return (kLog2Prime * DEGREE / 8) * (size_t)rank; +} + +static ossl_inline size_t encoded_public_key_size(int rank) +{ + return encoded_vector_size(rank) + /* sizeof(rho)= */ 32; +} + +/* + * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy + * necessary to encapsulate a secret. The entropy will be leaked to the + * decapsulating party. + */ +# define MLKEM_ENCAP_ENTROPY 32 + +/* MD&XOF handles */ + +/* Cache mgmt as per https://github.com/openssl/private/issues/700 */ + +ossl_mlkem_ctx *ossl_mlkem_newctx(OSSL_LIB_CTX *libctx, const char *properties) +{ + ossl_mlkem_ctx *nctx = OPENSSL_zalloc(sizeof(ossl_mlkem_ctx)); + + /* replacing static asserts: */ + if (nctx == NULL + || (OSSL_MLKEM768_SHARED_SECRET_BYTES != 32) + || (sizeof(unsigned int) < sizeof (uint32_t)) + || (sizeof(struct ossl_mlkem768_public_key) < + sizeof(struct public_key_RANK768)) + || (sizeof(struct ossl_mlkem768_private_key) < + sizeof(struct private_key_RANK768)) + || (encoded_public_key_size(RANK768) != OSSL_MLKEM768_PUBLIC_KEY_BYTES) + || (encoded_public_key_size(RANK1024) != OSSL_MLKEM1024_PUBLIC_KEY_BYTES) + || (ciphertext_size(RANK768) != OSSL_MLKEM768_CIPHERTEXT_BYTES) + || (ciphertext_size(RANK1024) != OSSL_MLKEM1024_CIPHERTEXT_BYTES)) + goto err; + + nctx->shake128_cache = EVP_MD_fetch(libctx, "SHAKE128", properties); + nctx->shake256_cache = EVP_MD_fetch(libctx, "SHAKE256", properties); + nctx->sha3_256_cache = EVP_MD_fetch(libctx, "SHA3-256", properties); + nctx->sha3_512_cache = EVP_MD_fetch(libctx, "SHA3-512", properties); + nctx->libctx = libctx; + if (properties != NULL) + if ((nctx->properties = OPENSSL_strdup(properties)) == NULL) + goto err; + if (nctx->shake128_cache == NULL || nctx->shake256_cache == NULL || + nctx->sha3_256_cache == NULL || nctx->sha3_512_cache == NULL) + goto err; + return nctx; + +err: + ERR_raise(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR); + ossl_mlkem_ctx_free(nctx); + return NULL; +} + +void ossl_mlkem_ctx_free(ossl_mlkem_ctx *ctx) +{ + if (ctx != NULL) { + EVP_MD_free(ctx->shake128_cache); + EVP_MD_free(ctx->shake256_cache); + EVP_MD_free(ctx->sha3_256_cache); + EVP_MD_free(ctx->sha3_512_cache); + OPENSSL_free(ctx->properties); + } + OPENSSL_free(ctx); +} + +/* + * single_keccak hashes |in_len| bytes from |in| and writes |out_len| bytes + * of output to |out|. If the |md| specifies a fixed-output function, like + * SHA3-256, then |out_len| must be the correct length for that function. + */ +static int single_keccak(uint8_t *out, size_t out_len, + const uint8_t *in, size_t in_len, + EVP_MD *md) +{ + EVP_MD_CTX *mdctx; + int ret = 0; + + mdctx = EVP_MD_CTX_new(); + if (mdctx == NULL + || !EVP_DigestInit_ex(mdctx, md, NULL) + || !EVP_DigestUpdate(mdctx, in, in_len)) + return 0; + + if (EVP_MD_xof(md)) + ret = EVP_DigestFinalXOF(mdctx, out, out_len); + else + ret = EVP_DigestFinal_ex(mdctx, out, NULL); + EVP_MD_CTX_free(mdctx); + + return ret; +} + +/* TODO(ML-KEM) revisit utility of this function/remove eventually */ +static void print_hex(const uint8_t *data, int len, const char *msg) +{ +# ifndef NDEBUG + if (msg) + printf("%s: \n", msg); + BIO_dump_fp(stdout, data, len); + fflush(0); +# endif +} + +/* + * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy + * necessary to encapsulate a secret. The entropy will be leaked to the + * decapsulating party. + */ +# define MLKEM_ENCAP_ENTROPY 32 + +/* See https://csrc.nist.gov/pubs/fips/203/final */ +static int prf(uint8_t *out, size_t out_len, const uint8_t in[33], + ossl_mlkem_ctx *mlkem_ctx) +{ + return single_keccak(out, out_len, in, 33, mlkem_ctx->shake256_cache); +} + +/* + * Section 4.1 + * uint8_t out[32] + */ +static int hash_h(uint8_t *out, const uint8_t *in, size_t len, + ossl_mlkem_ctx *mlkem_ctx) +{ + return single_keccak(out, 32, in, len, mlkem_ctx->sha3_256_cache); +} + +/* uint8_t out[64] */ +static int hash_g(uint8_t *out, const uint8_t *in, size_t len, + ossl_mlkem_ctx *mlkem_ctx) +{ + return single_keccak(out, 64, in, len, mlkem_ctx->sha3_512_cache); +} + +/* + * This is called `J` in the spec. + * uint8_t out[ossl_mlkem768_SHARED_SECRET_BYTES], + * const uint8_t failure_secret[32] + */ +static int kdf(uint8_t *out, + const uint8_t *failure_secret, const uint8_t *ciphertext, + size_t ciphertext_len, + ossl_mlkem_ctx *mlkem_ctx) +{ + EVP_MD_CTX *mdctx; + int ret = 0; + + mdctx = EVP_MD_CTX_new(); + if (mdctx == NULL + || !EVP_DigestInit_ex(mdctx, mlkem_ctx->shake256_cache, NULL) + || !EVP_DigestUpdate(mdctx, failure_secret, 32) + || !EVP_DigestUpdate(mdctx, ciphertext, ciphertext_len) + || !EVP_DigestFinalXOF(mdctx, out, OSSL_MLKEM768_SHARED_SECRET_BYTES)) + goto end; + + ret = 1; + +end: + EVP_MD_CTX_free(mdctx); + return ret; +} + +/* + * This bit of Python will be referenced in some of the following comments: + * + * p = 3329 + * + * def bitreverse(i): + * ret = 0 + * for n in range(7): + * bit = i & 1 + * ret <<= 1 + * ret |= bit + * i >>= 1 + * return ret + * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] + */ + +static const uint16_t kNTTRoots[128] = { + 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, + 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, + 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, + 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, + 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, + 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, +}; + +/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ +static const uint16_t kInverseNTTRoots[128] = { + 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, + 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, + 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, + 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, + 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, + 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, + 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, + 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, + 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, + 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, + 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, +}; + +/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */ +static const uint16_t kModRoots[128] = { + 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, + 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, + 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, + 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, + 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, + 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, + 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, + 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, + 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, + 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, + 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, +}; + +/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */ +static uint16_t reduce_once(uint16_t x) +{ + const uint16_t subtracted = x - kPrime; + uint16_t mask = 0u - (subtracted >> 15); + + assert(x < 2 * kPrime); + /* + * On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of + * ML-KEM overall and Clang still produces constant-time code using `csel`. On + * other platforms & compilers on godbolt that we care about, this code also + * produces constant-time output. + */ + return (mask & x) | (~mask & subtracted); +} + +/* + * constant time reduce x mod kPrime using Barrett reduction. x must be less + * than kPrime + 2xkPrime^2. + */ +static uint16_t reduce(uint32_t x) +{ + uint64_t product = (uint64_t)x * kBarrettMultiplier; + uint32_t quotient = (uint32_t)(product >> kBarrettShift); + uint32_t remainder = x - quotient * kPrime; + + assert(x < kPrime + 2u * kPrime * kPrime); + return reduce_once(remainder); +} + +static void scalar_zero(scalar *out) +{ + memset(out, 0, sizeof(*out)); +} + +static void vector_zero(vector *out) +{ + memset(out->v, 0, sizeof(scalar) * RANK768); +} + +/* + * In place number theoretic transform of a given scalar. + * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this + * transform leaves off the last iteration of the usual FFT code, with the 128 + * relevant roots of unity being stored in |kNTTRoots|. This means the output + * should be seen as 128 elements in GF(3329^2), with the coefficients of the + * elements being consecutive entries in |s->c|. + */ +static void scalar_ntt(scalar *s) +{ + int offset = DEGREE; + int k, step, i, j; + uint32_t step_root; + uint16_t odd, even; + + /* + * `int` is used here because using `size_t` throughout caused a ~5% slowdown + * with Clang 14 on Aarch64. + */ + for (step = 1; step < DEGREE / 2; step <<= 1) { + offset >>= 1; + k = 0; + for (i = 0; i < step; i++) { + step_root = kNTTRoots[i + step]; + for (j = k; j < k + offset; j++) { + odd = reduce(step_root * s->c[j + offset]); + even = s->c[j]; + s->c[j] = reduce_once(odd + even); + s->c[j + offset] = reduce_once(even - odd + kPrime); + } + k += 2 * offset; + } + } +} + +static void vector_ntt(vector *a) +{ + int i; + + for (i = 0; i < RANK768; i++) + scalar_ntt(&a->v[i]); +} + +/* + * In place inverse number theoretic transform of a given scalar, with pairs of + * entries of s->v being interpreted as elements of GF(3329^2). Just as with the + * number theoretic transform, this leaves off the first step of the normal iFFT + * to account for the fact that 3329 does not have a 512th root of unity, using + * the precomputed 128 roots of unity stored in |kInverseNTTRoots|. + */ +static void scalar_inverse_ntt(scalar *s) +{ + int step = DEGREE / 2; + int offset, k, i, j; + uint32_t step_root; + uint16_t odd, even; + + /* + * `int` is used here because using `size_t` throughout caused a ~5% slowdown + * with Clang 14 on Aarch64. + */ + for (offset = 2; offset < DEGREE; offset <<= 1) { + step >>= 1; + k = 0; + for (i = 0; i < step; i++) { + step_root = kInverseNTTRoots[i + step]; + for (j = k; j < k + offset; j++) { + odd = s->c[j + offset]; + even = s->c[j]; + s->c[j] = reduce_once(odd + even); + s->c[j + offset] = reduce(step_root * (even - odd + kPrime)); + } + k += 2 * offset; + } + } + for (i = 0; i < DEGREE; i++) + s->c[i] = reduce(s->c[i] * kInverseDegree); +} + +static void vector_inverse_ntt(vector *a) +{ + int i; + + for (i = 0; i < RANK768; i++) + scalar_inverse_ntt(&a->v[i]); +} + +static void scalar_add(scalar *lhs, const scalar *rhs) +{ + int i; + + for (i = 0; i < DEGREE; i++) + lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]); +} + +static void scalar_sub(scalar *lhs, const scalar *rhs) +{ + int i; + + for (i = 0; i < DEGREE; i++) + lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime); +} + +/* + * Multiplying two scalars in the number theoretically transformed state. Since + * 3329 does not have a 512th root of unity, this means we have to interpret + * the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2 + * - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is + * stored in the precomputed |kModRoots| table. Note that our Barrett transform + * only allows us to multipy two reduced numbers together, so we need some + * intermediate reduction steps, even if an uint64_t could hold 3 multiplied + * numbers. + */ +static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) +{ + int i; + uint32_t real_real, img_img, real_img, img_real; + + for (i = 0; i < DEGREE / 2; i++) { + real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; + img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1]; + real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; + img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; + out->c[2 * i] = + reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]); + out->c[2 * i + 1] = reduce(img_real + real_img); + } +} + +static void vector_add(vector *lhs, const vector *rhs) +{ + int i; + + for (i = 0; i < RANK768; i++) + scalar_add(&lhs->v[i], &rhs->v[i]); +} + +static void matrix_mult(vector *out, const matrix *m, + const vector *a) +{ + int i, j; + + vector_zero(out); + for (i = 0; i < RANK768; i++) { + for (j = 0; j < RANK768; j++) { + scalar product; + + scalar_mult(&product, &m->v[i][j], &a->v[j]); + scalar_add(&out->v[i], &product); + } + } +} + +static void matrix_mult_transpose(vector *out, const matrix *m, + const vector *a) +{ + int i, j; + + vector_zero(out); + for (i = 0; i < RANK768; i++) { + for (j = 0; j < RANK768; j++) { + scalar product; + + scalar_mult(&product, &m->v[j][i], &a->v[j]); + scalar_add(&out->v[i], &product); + } + } +} + +static void scalar_inner_product(scalar *out, const vector *lhs, + const vector *rhs) +{ + int i; + + scalar_zero(out); + for (i = 0; i < RANK768; i++) { + scalar product; + + scalar_mult(&product, &lhs->v[i], &rhs->v[i]); + scalar_add(out, &product); + } +} + +/* + * Algorithm 6 from the spec. Rejection samples a Keccak stream to get + * uniformly distributed elements. This is used for matrix expansion and only + * operates on public inputs. + */ +static int scalar_from_keccak_vartime(scalar *out, EVP_MD_CTX *mdctx) +{ + int done = 0; + uint8_t block[168]; + size_t i; + uint16_t d1, d2; + + while (done < DEGREE) { + if (!EVP_DigestSqueeze(mdctx, block, sizeof(block))) + return 0; + for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) { + d1 = block[i] + 256 * (block[i + 1] % 16); + d2 = block[i + 1] / 16 + 16 * block[i + 2]; + if (d1 < kPrime) + out->c[done++] = d1; + if (d2 < kPrime && done < DEGREE) + out->c[done++] = d2; + } + } + return 1; +} + +/* + * Algorithm 7 from the spec, with eta fixed to two and the PRF call + * included. Creates binominally distributed elements by sampling 2*|eta| bits, + * and setting the coefficient to the count of the first bits minus the count of + * the second bits, resulting in a centered binomial distribution. Since eta is + * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4, + * and 0 with probability 3/8. + */ +static +int scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out, + const uint8_t input[33], + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t entropy[128]; + int i; + uint8_t byte; + uint16_t value; + + assert(sizeof(entropy) == 2 * /* kEta= */ 2 * DEGREE / 8); + if (!prf(entropy, sizeof(entropy), input, mlkem_ctx)) + return 0; + for (i = 0; i < DEGREE; i += 2) { + byte = entropy[i / 2]; + value = kPrime; + value += (byte & 1) + ((byte >> 1) & 1); + value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); + out->c[i] = reduce_once(value); + byte >>= 4; + value = kPrime; + value += (byte & 1) + ((byte >> 1) & 1); + value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); + out->c[i + 1] = reduce_once(value); + } + return 1; +} + +/* + * Generates a secret vector by using + * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed + * appending and incrementing |counter| for entry of the vector. + */ +static int vector_generate_secret_eta_2(vector *out, uint8_t *counter, + const uint8_t seed[32], + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t input[33]; + int i; + + memcpy(input, seed, 32); + for (i = 0; i < RANK768; i++) { + input[32] = (*counter)++; + if (!scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], + input, mlkem_ctx)) + return 0; + } + return 1; +} + +/* Expands the matrix of a seed for key generation and for encaps-CPA. */ +static int matrix_expand(matrix *out, const uint8_t rho[32], + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t input[34]; + int i, j, ret = 0; + EVP_MD_CTX *mdctx = EVP_MD_CTX_new(); + + if (mdctx == NULL) + goto end; + memcpy(input, rho, 32); + for (i = 0; i < RANK768; i++) { + for (j = 0; j < RANK768; j++) { + input[32] = i; + input[33] = j; + if (!EVP_DigestInit_ex(mdctx, mlkem_ctx->shake128_cache, NULL) + || !EVP_DigestUpdate(mdctx, input, sizeof(input)) + || !scalar_from_keccak_vartime(&out->v[i][j], mdctx)) + goto end; + } + } + + ret = 1; +end: + EVP_MD_CTX_free(mdctx); + return ret; +} + +static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, + 0x1f, 0x3f, 0x7f, 0xff}; + +static void scalar_encode(uint8_t *out, const scalar *s, int bits) +{ + uint8_t out_byte = 0; + int out_byte_bits = 0; + int i, element_bits_done, chunk_bits, out_bits_remaining; + uint16_t element; + + assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1); + for (i = 0; i < DEGREE; i++) { + element = s->c[i]; + element_bits_done = 0; + while (element_bits_done < bits) { + chunk_bits = bits - element_bits_done; + out_bits_remaining = 8 - out_byte_bits; + if (chunk_bits >= out_bits_remaining) { + chunk_bits = out_bits_remaining; + out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits; + *out = out_byte; + out++; + out_byte_bits = 0; + out_byte = 0; + } else { + out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits; + out_byte_bits += chunk_bits; + } + element_bits_done += chunk_bits; + element >>= chunk_bits; + } + } + if (out_byte_bits > 0) + *out = out_byte; +} + +/* + * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. + * uint8_t out[32] + */ +static void scalar_encode_1(uint8_t *out, const scalar *s) +{ + int i, j; + uint8_t out_byte; + + for (i = 0; i < DEGREE; i += 8) { + out_byte = 0; + for (j = 0; j < 8; j++) + out_byte |= (s->c[i + j] & 1) << j; + *out = out_byte; + out++; + } +} + +/* + * Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256 + * (DEGREE) is divisible by 8, the individual vector entries will always fill a + * whole number of bytes, so we do not need to worry about bit packing here. + */ +static void vector_encode(uint8_t *out, const vector *a, int bits) +{ + int i; + + for (i = 0; i < RANK768; i++) + scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits); +} + +/* + * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in + * |out|. It returns one on success and zero if any parsed value is >= + * |kPrime|. + */ +static int scalar_decode(scalar *out, const uint8_t *in, int bits) +{ + uint8_t in_byte = 0; + int in_byte_bits_left = 0; + int i, element_bits_done, chunk_bits; + uint16_t element; + + if (!ossl_assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1)) + return 0; + for (i = 0; i < DEGREE; i++) { + element = 0; + element_bits_done = 0; + while (element_bits_done < bits) { + if (in_byte_bits_left == 0) { + in_byte = *in; + in++; + in_byte_bits_left = 8; + } + chunk_bits = bits - element_bits_done; + if (chunk_bits > in_byte_bits_left) + chunk_bits = in_byte_bits_left; + element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done; + in_byte_bits_left -= chunk_bits; + in_byte >>= chunk_bits; + element_bits_done += chunk_bits; + } + if (element >= kPrime) + return 0; + out->c[i] = element; + } + return 1; +} + +/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */ +static void scalar_decode_1(scalar *out, const uint8_t in[32]) +{ + int i, j; + uint8_t in_byte; + + for (i = 0; i < DEGREE; i += 8) { + in_byte = *in; + in++; + for (j = 0; j < 8; j++) { + out->c[i + j] = in_byte & 1; + in_byte >>= 1; + } + } +} + +/* + * Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on + * success or zero if any parsed value is >= |kPrime|. + */ +static int vector_decode(vector *out, const uint8_t *in, int bits) +{ + int i; + + for (i = 0; i < RANK768; i++) { + if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) + return 0; + } + return 1; +} + +/* + * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping + * numbers close to each other together. The formula used is + * round(2^|bits|/kPrime*x) mod 2^|bits|. + * Uses Barrett reduction to achieve constant time. Since we need both the + * remainder (for rounding) and the quotient (as the result), we cannot use + * |reduce| here, but need to do the Barrett reduction directly. + */ +static uint16_t compress(uint16_t x, int bits) +{ + uint32_t shifted = (uint32_t)x << bits; + uint64_t product = (uint64_t)shifted * kBarrettMultiplier; + uint32_t quotient = (uint32_t)(product >> kBarrettShift); + uint32_t remainder = shifted - quotient * kPrime; + + /* + * Adjust the quotient to round correctly: + * 0 <= remainder <= kHalfPrime round to 0 + * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1 + * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2 + */ + assert(remainder < 2u * kPrime); + quotient += 1 & constant_time_lt_32(kHalfPrime, remainder); + quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder); + return quotient & ((1 << bits) - 1); +} + +/* + * Decompresses |x| by using an equi-distant representative. The formula is + * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to + * implement this logic using only bit operations. + */ +static uint16_t decompress(uint16_t x, int bits) +{ + uint32_t product = (uint32_t)x * kPrime; + uint32_t power = 1 << bits; + /* This is |product| % power, since |power| is a power of 2. */ + uint32_t remainder = product & (power - 1); + /* This is |product| / power, since |power| is a power of 2. */ + uint32_t lower = product >> bits; + + /* + * The rounding logic works since the first half of numbers mod |power| have a + * 0 as first bit, and the second half has a 1 as first bit, since |power| is + * a power of 2. As a 12 bit number, |remainder| is always positive, so we + * will shift in 0s for a right shift. + */ + return lower + (remainder >> (bits - 1)); +} + +static void scalar_compress(scalar *s, int bits) +{ + int i; + + for (i = 0; i < DEGREE; i++) + s->c[i] = compress(s->c[i], bits); +} + +static void scalar_decompress(scalar *s, int bits) +{ + int i; + + for (i = 0; i < DEGREE; i++) + s->c[i] = decompress(s->c[i], bits); +} + +static void vector_compress(vector *a, int bits) +{ + int i; + + for (i = 0; i < RANK768; i++) + scalar_compress(&a->v[i], bits); +} + +static void vector_decompress(vector *a, int bits) +{ + int i; + + for (i = 0; i < RANK768; i++) + scalar_decompress(&a->v[i], bits); +} + +static +public_key_RANK768 *public_key_768_from_external(const ossl_mlkem768_public_key *external) +{ + return (struct public_key_RANK768 *)external; +} + +static +private_key_RANK768 *private_key_768_from_external(const ossl_mlkem768_private_key *external) +{ + return (struct private_key_RANK768 *)external; +} + +static int mlkem_marshal_public_key(uint8_t *out, + const struct public_key_RANK768 *pub) +{ + /* + * replace CBB logic with straight copy to out and memcpy of rho at tail end + * TODO(ML-KEM): Check this is OK to protect incorrect buffer(sizes) passed + * possibly use WPACKET? + */ + vector_encode(out, &pub->t, kLog2Prime); + memcpy(out + encoded_vector_size(RANK768), pub->rho, sizeof(pub->rho)); + return 1; +} + +int ossl_mlkem768_recreate_public_key(const uint8_t *encoded_public_key, + ossl_mlkem768_public_key *ext_pub, + ossl_mlkem_ctx *mlkem_ctx) +{ + struct public_key_RANK768 *pub = public_key_768_from_external(ext_pub); + + print_hex(encoded_public_key, OSSL_MLKEM768_PUBLIC_KEY_BYTES, "Encoded key"); + if (!vector_decode(&pub->t, encoded_public_key, kLog2Prime)) + return 0; + memcpy(pub->rho, encoded_public_key + encoded_vector_size(RANK768), sizeof(pub->rho)); + if (!matrix_expand(&pub->m, pub->rho, mlkem_ctx) + || (!hash_h(pub->public_key_hash, encoded_public_key, + encoded_public_key_size(RANK768), mlkem_ctx))) + return 0; + print_hex((uint8_t *)pub, sizeof(public_key_RANK768), "recreated PK"); + return 1; +} + +static int mlkem_generate_key_external_seed(uint8_t *out_encoded_public_key, + private_key_RANK768 *priv, + const uint8_t *seed, + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t augmented_seed[33]; + uint8_t hashed[64]; + const uint8_t *const rho = hashed; + const uint8_t *const sigma = hashed + 32; + uint8_t counter = 0; + vector error; + + if (mlkem_ctx == NULL) + return 0; + + memcpy(augmented_seed, seed, 32); + augmented_seed[32] = RANK768; + if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mlkem_ctx)) + return 0; + memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho)); + if (!matrix_expand(&priv->pub.m, rho, mlkem_ctx) + || (!vector_generate_secret_eta_2(&priv->s, &counter, sigma, mlkem_ctx))) + return 0; + vector_ntt(&priv->s); + if (!vector_generate_secret_eta_2(&error, &counter, sigma, mlkem_ctx)) + return 0; + vector_ntt(&error); + matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s); + vector_add(&priv->pub.t, &error); + if (!mlkem_marshal_public_key(out_encoded_public_key, &priv->pub) + || !hash_h(priv->pub.public_key_hash, out_encoded_public_key, + encoded_public_key_size(RANK768), mlkem_ctx)) + return 0; + memcpy(priv->fo_failure_secret, seed + 32, 32); + return 1; +} + +static +int ossl_mlkem768_generate_key_external_seed(uint8_t *out_encoded_public_key, + ossl_mlkem768_private_key *out_private_key, + const uint8_t *seed, + ossl_mlkem_ctx *mlkem_ctx) +{ + private_key_RANK768 *priv = NULL; + + priv = private_key_768_from_external(out_private_key); + return mlkem_generate_key_external_seed(out_encoded_public_key, priv, seed, mlkem_ctx); +} + +int ossl_mlkem768_generate_key(uint8_t *out_encoded_public_key, + uint8_t *optional_out_seed, + ossl_mlkem768_private_key *out_private_key, + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t seed[MLKEM_SEED_BYTES]; + + if (mlkem_ctx == NULL) + return 0; + + /* TODO(ML-KEM): Review requested randomness strength */ + if (RAND_priv_bytes_ex(mlkem_ctx->libctx, seed, sizeof(seed), 256) == 1) { + if (optional_out_seed) + memcpy(optional_out_seed, seed, sizeof(seed)); + return ossl_mlkem768_generate_key_external_seed(out_encoded_public_key, + out_private_key, + seed, mlkem_ctx); + } + return 0; +} + +int ossl_mlkem768_private_key_from_seed(ossl_mlkem768_private_key *out_private_key, + const uint8_t *seed, size_t seed_len, + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t public_key_bytes[OSSL_MLKEM768_PUBLIC_KEY_BYTES]; + + if (seed_len != MLKEM_SEED_BYTES) + return 0; + ossl_mlkem768_generate_key_external_seed(public_key_bytes, out_private_key, + seed, mlkem_ctx); + return 1; +} + +int ossl_mlkem768_public_from_private(ossl_mlkem768_public_key *out_public_key, + const ossl_mlkem768_private_key *private_key) +{ + struct public_key_RANK768 *const pub = public_key_768_from_external(out_public_key); + const struct private_key_RANK768 *const priv = + private_key_768_from_external(private_key); + + if (priv == NULL) + return 0; + *pub = priv->pub; + return 1; +} + +/* + * Encrypts a message with given randomness to + * the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this + * would not result in a CCA secure scheme, since lattice schemes are vulnerable + * to decryption failure oracles. + */ +static int encrypt_cpa(uint8_t *out, const struct public_key_RANK768 *pub, + const uint8_t *message, + const uint8_t *randomness, + ossl_mlkem_ctx *mlkem_ctx) +{ + int du = kDU768; + int dv = kDV768; + uint8_t counter = 0; + vector secret, error; + uint8_t input[33]; + scalar scalar_error; + vector u; + scalar v; + scalar expanded_message; + + if (!vector_generate_secret_eta_2(&secret, &counter, randomness, mlkem_ctx)) + return 0; + vector_ntt(&secret); + if (!vector_generate_secret_eta_2(&error, &counter, randomness, mlkem_ctx)) + return 0; + memcpy(input, randomness, 32); + input[32] = counter; + if (!scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input, mlkem_ctx)) + return 0; + matrix_mult(&u, &pub->m, &secret); + vector_inverse_ntt(&u); + vector_add(&u, &error); + scalar_inner_product(&v, &pub->t, &secret); + scalar_inverse_ntt(&v); + scalar_add(&v, &scalar_error); + scalar_decode_1(&expanded_message, message); + scalar_decompress(&expanded_message, 1); + scalar_add(&v, &expanded_message); + vector_compress(&u, du); + vector_encode(out, &u, du); + scalar_compress(&v, dv); + scalar_encode(out + compressed_vector_size(RANK768), &v, dv); + return 1; +} + +/* + * See section 6.2. + * entropy[MLKEM_ENCAP_ENTROPY]) + */ +static int mlkem_encap_external_entropy(uint8_t *out_ciphertext, + uint8_t *out_shared_secret, + const public_key_RANK768 *pub, + const uint8_t *entropy, + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t input[64]; + uint8_t key_and_randomness[64]; + + memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); + memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash, + sizeof(input) - MLKEM_ENCAP_ENTROPY); + if (!hash_g(key_and_randomness, input, sizeof(input), mlkem_ctx) + || !encrypt_cpa(out_ciphertext, pub, entropy, + key_and_randomness + 32, mlkem_ctx)) + return 0; + memcpy(out_shared_secret, key_and_randomness, 32); + return 1; +} + +/* + * out_ciphertext[ossl_mlkem768_CIPHERTEXT_BYTES], + * out_shared_secret[ossl_mlkem768_SHARED_SECRET_BYTES], + * entropy[MLKEM_ENCAP_ENTROPY]) + */ +static +int ossl_mlkem768_encap_external_entropy(uint8_t *out_ciphertext, + uint8_t *out_shared_secret, + const ossl_mlkem768_public_key *public_key, + const uint8_t *entropy, + ossl_mlkem_ctx *mlkem_ctx) +{ + const struct public_key_RANK768 *pub = + public_key_768_from_external(public_key); + + return mlkem_encap_external_entropy(out_ciphertext, out_shared_secret, pub, + entropy, mlkem_ctx); +} + +/* Calls |ossl_mlkem768_encap_external_entropy| with random bytes from |RAND_bytes| */ +int ossl_mlkem768_encap(uint8_t *out_ciphertext, + uint8_t *out_shared_secret, + const ossl_mlkem768_public_key *public_key, + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t entropy[MLKEM_ENCAP_ENTROPY]; + + if (mlkem_ctx == NULL) + return 0; + + /* TODO(ML-KEM): Review requested randomness strength */ + if (RAND_bytes_ex(mlkem_ctx->libctx, entropy, MLKEM_ENCAP_ENTROPY, 256) != 1 + || !ossl_mlkem768_encap_external_entropy(out_ciphertext, out_shared_secret, public_key, + entropy, mlkem_ctx)) + return 0; + print_hex((uint8_t *)public_key, sizeof(ossl_mlkem768_public_key), "PK"); + print_hex(out_shared_secret, OSSL_MLKEM768_SHARED_SECRET_BYTES, "SS2"); + print_hex(out_ciphertext, OSSL_MLKEM768_CIPHERTEXT_BYTES, "CT2"); + return 1; +} + +static void decrypt_cpa(uint8_t *out, const struct private_key_RANK768 *priv, + const uint8_t *ciphertext) +{ + int du = kDU768; + int dv = kDV768; + vector u; + scalar v, mask; + + vector_decode(&u, ciphertext, du); + vector_decompress(&u, du); + vector_ntt(&u); + scalar_decode(&v, ciphertext + compressed_vector_size(RANK768), dv); + scalar_decompress(&v, dv); + scalar_inner_product(&mask, &priv->s, &u); + scalar_inverse_ntt(&mask); + scalar_sub(&v, &mask); + scalar_compress(&v, 1); + scalar_encode_1(out, &v); +} + +/* See section 6.3 */ +static int mlkem_decap(uint8_t *out_shared_secret, + const uint8_t *ciphertext, + const struct private_key_RANK768 *priv, + ossl_mlkem_ctx *mlkem_ctx) +{ + uint8_t decrypted[64]; + uint8_t key_and_randomness[64]; + size_t ciphertext_len = ciphertext_size(RANK768); + /* TODO(ML-KEM): Maximum also applicable for other algs? */ + uint8_t expected_ciphertext[OSSL_MLKEM1024_CIPHERTEXT_BYTES]; + uint8_t failure_key[32]; + uint8_t mask; + int i; + + print_hex((uint8_t *)&priv->pub, sizeof(ossl_mlkem768_public_key), "PK1"); + print_hex(ciphertext, OSSL_MLKEM768_CIPHERTEXT_BYTES, "CT"); + decrypt_cpa(decrypted, priv, ciphertext); + memcpy(decrypted + 32, priv->pub.public_key_hash, + sizeof(decrypted) - 32); + if (!hash_g(key_and_randomness, decrypted, sizeof(decrypted), mlkem_ctx)) + return 0; + assert(ciphertext_len <= sizeof(expected_ciphertext)); + encrypt_cpa(expected_ciphertext, &priv->pub, decrypted, + key_and_randomness + 32, mlkem_ctx); + kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len, mlkem_ctx); + mask = constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, + expected_ciphertext, ciphertext_len), 0); + for (i = 0; i < OSSL_MLKEM768_SHARED_SECRET_BYTES; i++) + out_shared_secret[i] = constant_time_select_8(mask, + key_and_randomness[i], + failure_key[i]); + print_hex(out_shared_secret, OSSL_MLKEM768_SHARED_SECRET_BYTES, "SS"); + return 1; +} + +int ossl_mlkem768_decap(uint8_t *out_shared_secret, + const uint8_t *ciphertext, size_t ciphertext_len, + const ossl_mlkem768_private_key *private_key, + ossl_mlkem_ctx *mlkem_ctx) +{ + const struct private_key_RANK768 *priv; + + if (mlkem_ctx == NULL) + return 0; + + if (ciphertext_len != OSSL_MLKEM768_CIPHERTEXT_BYTES) { + /* TODO(ML-KEM): Review requested randomness strength */ + RAND_bytes_ex(mlkem_ctx->libctx, out_shared_secret, + OSSL_MLKEM768_SHARED_SECRET_BYTES, 256); + return 0; + } + priv = private_key_768_from_external(private_key); + return mlkem_decap(out_shared_secret, ciphertext, priv, mlkem_ctx); +} + +#endif /* OPENSSL_NO_MLKEM */ diff --git a/include/crypto/mlkem.h b/include/crypto/mlkem.h new file mode 100644 index 00000000000..3b4321c83bc --- /dev/null +++ b/include/crypto/mlkem.h @@ -0,0 +1,179 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +/* Copyright (c) 2024, Google Inc. */ + +#ifndef OPENSSL_HEADER_MLKEM_H +# define OPENSSL_HEADER_MLKEM_H + +# include +# include +# include + +# if defined(__cplusplus) +extern "C" { +# endif + +# ifndef OPENSSL_NO_MLKEM + + typedef struct ossl_mlkem_ctx { + EVP_MD *shake128_cache; + EVP_MD *shake256_cache; + EVP_MD *sha3_256_cache; + EVP_MD *sha3_512_cache; + OSSL_LIB_CTX *libctx; + char *properties; + } ossl_mlkem_ctx; + + /* General ctx functions */ + ossl_mlkem_ctx *ossl_mlkem_newctx(OSSL_LIB_CTX *libctx, const char *properties); + + void ossl_mlkem_ctx_free(ossl_mlkem_ctx *ctx); + + /* + * ML-KEM-768. + * + * This implements the Module-Lattice-Based Key-Encapsulation Mechanism from + * https://csrc.nist.gov/pubs/fips/203/final + */ + + /* + * ossl_mlkem768_public_key contains an ML-KEM-768 public key. The contents of this + * object should never leave the address space since the format is unstable. + */ + + /* TODO: Review alignment as per https://github.com/openssl/private/issues/702 */ + + typedef struct ossl_mlkem768_public_key { + union { + uint8_t bytes[512 * (3 + 9) + 32 + 32]; + uint16_t alignment; + } opaque; + } ossl_mlkem768_public_key; + + /* + * ossl_mlkem768_private_key contains an ML-KEM-768 private key. The contents of this + * object should never leave the address space since the format is unstable. + */ + typedef struct ossl_mlkem768_private_key { + union { + uint8_t bytes[512 * (3 + 3 + 9) + 32 + 32 + 32]; + uint16_t alignment; + } opaque; + } ossl_mlkem768_private_key; + +/* + * Parameters from FIPS 203 Section 8: Parameter Sets + * Reference: https://csrc.nist.gov/pubs/fips/203/final + */ + +# define OSSL_MLKEM768_SECURITY_BITS 192 + + /* + * OSSL_MLKEM1024_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-1024 + * public key. + */ +# define OSSL_MLKEM1024_PUBLIC_KEY_BYTES 1568 + + /* + * OSSL_MLKEM768_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-768 + * public key. + */ +# define OSSL_MLKEM768_PUBLIC_KEY_BYTES 1184 + + /* MLKEM_SEED_BYTES is the number of bytes in an ML-KEM seed. */ +# define MLKEM_SEED_BYTES 64 + + /* + * ossl_mlkem768_generate_key generates a random public/private key pair, writes the + * encoded public key to |out_encoded_public_key| and sets |out_private_key| to + * the private key. If |optional_out_seed| is not NULL then the seed used to + * generate the private key is written to it as well. + * out_encoded_public_key must be allocated to store ossl_mlkem768_PUBLIC_KEY_BYTES + * optional_out_seed if not NULL must be allocated to store MLKEM_SEED_BYTES + */ + int ossl_mlkem768_generate_key(uint8_t *out_encoded_public_key, + uint8_t *optional_out_seed, + struct ossl_mlkem768_private_key *out_private_key, + ossl_mlkem_ctx *mlkem_ctx); + + /* + * ossl_mlkem768_private_key_from_seed derives a private key from a seed that was + * generated by |ossl_mlkem768_generate_key|. It fails and returns 0 if |seed_len| is + * incorrect, otherwise it writes |*out_private_key| and returns 1. + */ + int ossl_mlkem768_private_key_from_seed(ossl_mlkem768_private_key *out_private_key, + const uint8_t *seed, + size_t seed_len, + ossl_mlkem_ctx *mlkem_ctx); + + /* + * ossl_mlkem768_public_from_private sets |*out_public_key| to the public key that + * corresponds to |private_key|. (This is faster than parsing the output of + * |ossl_mlkem768_generate_key| if, for some reason, you need to encapsulate to a key + * that was just generated.) + */ + int ossl_mlkem768_public_from_private(ossl_mlkem768_public_key *out_public_key, + const ossl_mlkem768_private_key *private_key); + + /* ossl_mlkem1024_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-1024 ciphertext. */ +# define OSSL_MLKEM1024_CIPHERTEXT_BYTES 1568 + + /* ossl_mlkem768_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-768 ciphertext. */ +# define OSSL_MLKEM768_CIPHERTEXT_BYTES 1088 + + /* ossl_mlkem768_SHARED_SECRET_BYTES is the number of bytes in an ML-KEM shared secret. */ +# define OSSL_MLKEM768_SHARED_SECRET_BYTES 32 + + /* + * ossl_mlkem768_encap encrypts a random shared secret for |public_key|, writes the + * ciphertext to |out_ciphertext|, and writes the random shared secret to + * |out_shared_secret|. + * it is assumed out_ciphertext has been allocated ossl_mlkem768_CIPHERTEXT_BYTES bytes + * and out_shared_secret has been allocated MLKEM_SHARED_SECRET_BYTES bytes + */ + int ossl_mlkem768_encap(uint8_t *out_ciphertext, + uint8_t *out_shared_secret, + const ossl_mlkem768_public_key *public_key, + ossl_mlkem_ctx *mlkem_ctx); + + /* + * ossl_mlkem768_decap decrypts a shared secret from |ciphertext| using |private_key| + * and writes it to |out_shared_secret|. If |ciphertext_len| is incorrect it + * returns 0, otherwise it returns 1. If |ciphertext| is invalid (but of the + * correct length), |out_shared_secret| is filled with a key that will always be + * the same for the same |ciphertext| and |private_key|, but which appears to be + * random unless one has access to |private_key|. These alternatives occur in + * constant time. Any subsequent symmetric encryption using |out_shared_secret| + * must use an authenticated encryption scheme in order to discover the + * decapsulation failure. + * it is assumed out_shared_secret has been allocated MLKEM_SHARED_SECRET_BYTES bytes + */ + int ossl_mlkem768_decap(uint8_t *out_shared_secret, + const uint8_t *ciphertext, size_t ciphertext_len, + const ossl_mlkem768_private_key *private_key, + ossl_mlkem_ctx *mlkem_ctx); + + /* + * ossl_mlkem768_recreate_public_key recreates a fully formed ossl_mlkem768_public_key + * from an input |encoded_public_key| of size ossl_mlkem768_PUBLIC_KEY_BYTES. + * |pub| is expected to point to an allocated memory area of + * sizeof(ossl_mlkem768_public_key) + */ + int ossl_mlkem768_recreate_public_key(const uint8_t *encoded_public_key, + ossl_mlkem768_public_key *pub, + ossl_mlkem_ctx *mlkem_ctx); + +# endif /* OPENSSL_NO_MLKEM */ + +# if defined(__cplusplus) +} /* extern C */ +# endif + +#endif /* OPENSSL_HEADER_MLKEM_H */ diff --git a/include/internal/constant_time.h b/include/internal/constant_time.h index 1f480d84d88..2cb4bbc2cd1 100644 --- a/include/internal/constant_time.h +++ b/include/internal/constant_time.h @@ -42,7 +42,6 @@ static ossl_inline unsigned int constant_time_lt(unsigned int a, /* Convenience method for getting an 8-bit mask. */ static ossl_inline unsigned char constant_time_lt_8(unsigned int a, unsigned int b); - /* Convenience method for uint32_t. */ static ossl_inline uint32_t constant_time_lt_32(uint32_t a, uint32_t b); diff --git a/include/internal/tlsgroups.h b/include/internal/tlsgroups.h index 73fb53bc5ff..2507bb18877 100644 --- a/include/internal/tlsgroups.h +++ b/include/internal/tlsgroups.h @@ -56,5 +56,12 @@ # define OSSL_TLS_GROUP_ID_ffdhe4096 0x0102 # define OSSL_TLS_GROUP_ID_ffdhe6144 0x0103 # define OSSL_TLS_GROUP_ID_ffdhe8192 0x0104 + /* + * TODO(ML-KEM): Update to 513 as per IANA + * https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-8 + * Not done yet to not break interop testing with OQS test server; change when that + * gets updated in line with whatever IANA eventually defines + */ +# define OSSL_TLS_GROUP_ID_mlkem768 0x0768 #endif diff --git a/providers/common/capabilities.c b/providers/common/capabilities.c index 78099ecf659..6ca2b39efeb 100644 --- a/providers/common/capabilities.c +++ b/providers/common/capabilities.c @@ -18,6 +18,7 @@ #include "internal/tlsgroups.h" #include "prov/providercommon.h" #include "internal/e_os.h" +#include "crypto/mlkem.h" /* If neither ec or dh is available then we have no TLS-GROUP capabilities */ #if !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_DH) @@ -28,73 +29,76 @@ typedef struct tls_group_constants_st { int maxtls; /* Maximum TLS version (or 0 for undefined) */ int mindtls; /* Minimum DTLS version, -1 unsupported */ int maxdtls; /* Maximum DTLS version (or 0 for undefined) */ + int is_kem; /* Indicates utility as KEM */ } TLS_GROUP_CONSTANTS; static const TLS_GROUP_CONSTANTS group_list[] = { { OSSL_TLS_GROUP_ID_sect163k1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect163r1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect163r2, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect193r1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect193r2, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect233k1, 112, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect233r1, 112, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect239k1, 112, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect283k1, 128, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect283r1, 128, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect409k1, 192, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect409r1, 192, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect571k1, 256, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_sect571r1, 256, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp160k1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp160r1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp160r2, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp192k1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp192r1, 80, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp224k1, 112, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp224r1, 112, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_secp256k1, 128, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, - { OSSL_TLS_GROUP_ID_secp256r1, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0 }, - { OSSL_TLS_GROUP_ID_secp384r1, 192, TLS1_VERSION, 0, DTLS1_VERSION, 0 }, - { OSSL_TLS_GROUP_ID_secp521r1, 256, TLS1_VERSION, 0, DTLS1_VERSION, 0 }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, + { OSSL_TLS_GROUP_ID_secp256r1, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 }, + { OSSL_TLS_GROUP_ID_secp384r1, 192, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 }, + { OSSL_TLS_GROUP_ID_secp521r1, 256, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 }, { OSSL_TLS_GROUP_ID_brainpoolP256r1, 128, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_brainpoolP384r1, 192, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, { OSSL_TLS_GROUP_ID_brainpoolP512r1, 256, TLS1_VERSION, TLS1_2_VERSION, - DTLS1_VERSION, DTLS1_2_VERSION }, - { OSSL_TLS_GROUP_ID_x25519, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0 }, - { OSSL_TLS_GROUP_ID_x448, 224, TLS1_VERSION, 0, DTLS1_VERSION, 0 }, - { OSSL_TLS_GROUP_ID_brainpoolP256r1_tls13, 128, TLS1_3_VERSION, 0, -1, -1 }, - { OSSL_TLS_GROUP_ID_brainpoolP384r1_tls13, 192, TLS1_3_VERSION, 0, -1, -1 }, - { OSSL_TLS_GROUP_ID_brainpoolP512r1_tls13, 256, TLS1_3_VERSION, 0, -1, -1 }, + DTLS1_VERSION, DTLS1_2_VERSION, 0 }, + { OSSL_TLS_GROUP_ID_x25519, 128, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 }, + { OSSL_TLS_GROUP_ID_x448, 224, TLS1_VERSION, 0, DTLS1_VERSION, 0, 0 }, + { OSSL_TLS_GROUP_ID_brainpoolP256r1_tls13, 128, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_brainpoolP384r1_tls13, 192, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_brainpoolP512r1_tls13, 256, TLS1_3_VERSION, 0, -1, -1, 0 }, /* Security bit values as given by BN_security_bits() */ - { OSSL_TLS_GROUP_ID_ffdhe2048, 112, TLS1_3_VERSION, 0, -1, -1 }, - { OSSL_TLS_GROUP_ID_ffdhe3072, 128, TLS1_3_VERSION, 0, -1, -1 }, - { OSSL_TLS_GROUP_ID_ffdhe4096, 128, TLS1_3_VERSION, 0, -1, -1 }, - { OSSL_TLS_GROUP_ID_ffdhe6144, 128, TLS1_3_VERSION, 0, -1, -1 }, - { OSSL_TLS_GROUP_ID_ffdhe8192, 192, TLS1_3_VERSION, 0, -1, -1 }, + { OSSL_TLS_GROUP_ID_ffdhe2048, 112, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_ffdhe3072, 128, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_ffdhe4096, 128, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_ffdhe6144, 128, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_ffdhe8192, 192, TLS1_3_VERSION, 0, -1, -1, 0 }, + { OSSL_TLS_GROUP_ID_mlkem768, OSSL_MLKEM768_SECURITY_BITS, + TLS1_3_VERSION, 0, -1, -1, 1 }, }; #define TLS_GROUP_ENTRY(tlsname, realname, algorithm, idx) \ @@ -120,10 +124,12 @@ static const TLS_GROUP_CONSTANTS group_list[] = { (unsigned int *)&group_list[idx].mindtls), \ OSSL_PARAM_int(OSSL_CAPABILITY_TLS_GROUP_MAX_DTLS, \ (unsigned int *)&group_list[idx].maxdtls), \ + OSSL_PARAM_int(OSSL_CAPABILITY_TLS_GROUP_IS_KEM, \ + (unsigned int *)&group_list[idx].is_kem), \ OSSL_PARAM_END \ } -static const OSSL_PARAM param_group_list[][10] = { +static const OSSL_PARAM param_group_list[][11] = { # ifndef OPENSSL_NO_EC # ifndef OPENSSL_NO_EC2M TLS_GROUP_ENTRY("sect163k1", "sect163k1", "EC", 0), @@ -204,6 +210,8 @@ static const OSSL_PARAM param_group_list[][10] = { TLS_GROUP_ENTRY("ffdhe6144", "ffdhe6144", "DH", 36), TLS_GROUP_ENTRY("ffdhe8192", "ffdhe8192", "DH", 37), # endif + /* TODO(ML-KEM): Decide final name, e.g., ML-KEM768 or MLKEM768 */ + TLS_GROUP_ENTRY("MLKEM768", "MLKEM768", "ML-KEM-768", 38), }; #endif /* !defined(OPENSSL_NO_EC) || !defined(OPENSSL_NO_DH) */ diff --git a/providers/defltprov.c b/providers/defltprov.c index ccc1e8e7e9b..9b8781e532e 100644 --- a/providers/defltprov.c +++ b/providers/defltprov.c @@ -482,6 +482,7 @@ static const OSSL_ALGORITHM deflt_asym_kem[] = { # endif { PROV_NAMES_EC, "provider=default", ossl_ec_asym_kem_functions }, #endif + { PROV_NAMES_MLKEM768, "provider=default", ossl_mlkem768_asym_kem_functions }, { NULL, NULL, NULL } }; @@ -544,6 +545,8 @@ static const OSSL_ALGORITHM deflt_keymgmt[] = { { PROV_NAMES_SM2, "provider=default", ossl_sm2_keymgmt_functions, PROV_DESCS_SM2 }, #endif + { PROV_NAMES_MLKEM768, "provider=default", ossl_mlkem768_keymgmt_functions, + PROV_DESCS_MLKEM768 }, { NULL, NULL, NULL } }; diff --git a/providers/implementations/include/prov/implementations.h b/providers/implementations/include/prov/implementations.h index 3863d96a407..c156c3b2bf3 100644 --- a/providers/implementations/include/prov/implementations.h +++ b/providers/implementations/include/prov/implementations.h @@ -324,6 +324,7 @@ extern const OSSL_DISPATCH ossl_sm2_keymgmt_functions[]; extern const OSSL_DISPATCH ossl_ml_dsa_44_keymgmt_functions[]; extern const OSSL_DISPATCH ossl_ml_dsa_65_keymgmt_functions[]; extern const OSSL_DISPATCH ossl_ml_dsa_87_keymgmt_functions[]; +extern const OSSL_DISPATCH ossl_mlkem768_keymgmt_functions[]; /* Key Exchange */ extern const OSSL_DISPATCH ossl_dh_keyexch_functions[]; @@ -400,6 +401,7 @@ extern const OSSL_DISPATCH ossl_sm2_asym_cipher_functions[]; extern const OSSL_DISPATCH ossl_rsa_asym_kem_functions[]; extern const OSSL_DISPATCH ossl_ecx_asym_kem_functions[]; extern const OSSL_DISPATCH ossl_ec_asym_kem_functions[]; +extern const OSSL_DISPATCH ossl_mlkem768_asym_kem_functions[]; /* Encoders */ extern const OSSL_DISPATCH ossl_rsa_to_PKCS1_der_encoder_functions[]; diff --git a/providers/implementations/include/prov/mlkem.h b/providers/implementations/include/prov/mlkem.h new file mode 100644 index 00000000000..4eb4202f3fb --- /dev/null +++ b/providers/implementations/include/prov/mlkem.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +#ifndef OSSL_INTERNAL_MLKEM_H +# define OSSL_INTERNAL_MLKEM_H +# pragma once + +# ifndef OPENSSL_NO_MLKEM + +# include +# include + +# define MLKEM_KEY_TYPE_512 0 +# define MLKEM_KEY_TYPE_768 1 +# define MLKEY_KEY_TYPE_1024 2 + +typedef struct mlkem768_key_st { + int keytype; + ossl_mlkem768_private_key seckey; + ossl_mlkem768_public_key pubkey; + uint8_t *encoded_pubkey; + int pubkey_initialized; + int seckey_initialized; + ossl_mlkem_ctx *mlkem_ctx; + void *provctx; +} MLKEM768_KEY; + +# endif /* OPENSSL_NO_MLKEM */ + +#endif /* OSSL_INTERNAL_MLKEM_H */ diff --git a/providers/implementations/include/prov/names.h b/providers/implementations/include/prov/names.h index 9280be0bbea..e556be86bb4 100644 --- a/providers/implementations/include/prov/names.h +++ b/providers/implementations/include/prov/names.h @@ -390,3 +390,5 @@ #define PROV_DESCS_ML_DSA_65 "OpenSSL ML-DSA-65 implementation" #define PROV_NAMES_ML_DSA_87 "ML-DSA-87:2.16.840.1.101.3.4.3.19:id-ml-dsa-87" #define PROV_DESCS_ML_DSA_87 "OpenSSL ML-DSA-87 implementation" +#define PROV_NAMES_MLKEM768 "ML-KEM-768" +#define PROV_DESCS_MLKEM768 "OpenSSL ML-KEM-768 implementation" diff --git a/providers/implementations/kem/build.info b/providers/implementations/kem/build.info index 4a6a58ff654..b452323c219 100644 --- a/providers/implementations/kem/build.info +++ b/providers/implementations/kem/build.info @@ -4,6 +4,7 @@ $RSA_KEM_GOAL=../../libdefault.a ../../libfips.a $EC_KEM_GOAL=../../libdefault.a $TEMPLATE_KEM_GOAL=../../libtemplate.a +$ML_KEM_GOAL=../../libdefault.a SOURCE[$RSA_KEM_GOAL]=rsa_kem.c @@ -15,3 +16,4 @@ IF[{- !$disabled{ec} -}] ENDIF SOURCE[$TEMPLATE_KEM_GOAL]=template_kem.c +SOURCE[$ML_KEM_GOAL]=ml_kem.c diff --git a/providers/implementations/kem/ml_kem.c b/providers/implementations/kem/ml_kem.c new file mode 100644 index 00000000000..99a55dddfe8 --- /dev/null +++ b/providers/implementations/kem/ml_kem.c @@ -0,0 +1,210 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "prov/provider_ctx.h" +#include "prov/implementations.h" +#include "prov/securitycheck.h" +#include "prov/providercommon.h" +#include "prov/mlkem.h" + +#define BUFSIZE 1000 +#if defined(NDEBUG) || defined(OPENSSL_NO_STDIO) +/* TODO(ML-KEM) to remove or replace with TRACE */ +static void debug_print(char *fmt, ...) +{ +} +#else +static void debug_print(char *fmt, ...) +{ + char out[BUFSIZE]; + va_list argptr; + + va_start(argptr, fmt); + vsnprintf(out, BUFSIZE, fmt, argptr); + va_end(argptr); + if (getenv("TEMPLATEKM")) + fprintf(stderr, "TEMPLATE_KM: %s", out); +} +#endif + +typedef struct { + OSSL_LIB_CTX *libctx; + MLKEM768_KEY *key; + int op; +} PROV_MLKEM_CTX; + +static OSSL_FUNC_kem_newctx_fn mlkem_newctx; +static OSSL_FUNC_kem_encapsulate_init_fn mlkem_encapsulate_init; +static OSSL_FUNC_kem_encapsulate_fn mlkem_encapsulate; +static OSSL_FUNC_kem_decapsulate_init_fn mlkem_decapsulate_init; +static OSSL_FUNC_kem_decapsulate_fn mlkem_decapsulate; +static OSSL_FUNC_kem_freectx_fn mlkem_freectx; +static OSSL_FUNC_kem_set_ctx_params_fn mlkem_set_ctx_params; + +static void *mlkem_newctx(void *provctx) +{ + PROV_MLKEM_CTX *ctx = OPENSSL_zalloc(sizeof(*ctx)); + + debug_print("MLKEMKEM newctx called\n"); + if (ctx == NULL) + return NULL; + + ctx->libctx = PROV_LIBCTX_OF(provctx); + + debug_print("MLKEMKEM newctx returns %p\n", ctx); + return ctx; +} + +static void mlkem_freectx(void *vctx) +{ + PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx; + + debug_print("MLKEMKEM freectx %p\n", ctx); + OPENSSL_free(ctx); +} + +static int mlkem_init(void *vctx, int operation, void *vkey, void *vauth, + ossl_unused const OSSL_PARAM params[]) +{ + PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx; + MLKEM768_KEY *mlkemkey = vkey; + + debug_print("MLKEMKEM init %p / %p\n", ctx, mlkemkey); + if (!ossl_prov_is_running()) + return 0; + + if (mlkemkey->keytype != MLKEM_KEY_TYPE_768) + return 0; + + ctx->key = mlkemkey; + ctx->op = operation; + debug_print("MLKEMKEM init OK\n"); + return 1; +} + +static int mlkem_encapsulate_init(void *vctx, void *vkey, + const OSSL_PARAM params[]) +{ + return mlkem_init(vctx, EVP_PKEY_OP_ENCAPSULATE, vkey, NULL, params); +} + +static int mlkem_decapsulate_init(void *vctx, void *vkey, + const OSSL_PARAM params[]) +{ + return mlkem_init(vctx, EVP_PKEY_OP_DECAPSULATE, vkey, NULL, params); +} + +static int mlkem_set_ctx_params(void *vctx, const OSSL_PARAM params[]) +{ + PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx; + + debug_print("MLKEMKEM set ctx params %p\n", ctx); + if (ctx == NULL) + return 0; + if (params == NULL) + return 1; + + debug_print("MLKEMKEM set ctx params OK\n"); + return 1; +} + +static const OSSL_PARAM known_settable_mlkem_ctx_params[] = { + OSSL_PARAM_END +}; + +static const OSSL_PARAM *mlkem_settable_ctx_params(ossl_unused void *vctx, + ossl_unused void *provctx) +{ + return known_settable_mlkem_ctx_params; +} + +static int mlkem_encapsulate(void *vctx, unsigned char *out, size_t *outlen, + unsigned char *secret, size_t *secretlen) +{ + PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx; + int ret; + + debug_print("MLKEMKEM encaps %p to %p\n", ctx, out); + if (outlen != NULL) + *outlen = OSSL_MLKEM768_CIPHERTEXT_BYTES; + if (secretlen != NULL) + *secretlen = OSSL_MLKEM768_SHARED_SECRET_BYTES; + + if (out == NULL) { + debug_print("MLKEMKEM encaps outlens set to %ld and %ld\n", *outlen, *secretlen); + return 1; + } + + if (ctx->key == NULL + || ctx->key->keytype != MLKEM_KEY_TYPE_768 + || ctx->key->pubkey_initialized == 0 + || secret == NULL) + return 0; + + ret = ossl_mlkem768_encap(out, (uint8_t *)secret, &ctx->key->pubkey, ctx->key->mlkem_ctx); + + debug_print("MLKEMKEM encaps returns %d\n", ret); + return ret; +} + +static int mlkem_decapsulate(void *vctx, unsigned char *out, size_t *outlen, + const unsigned char *in, size_t inlen) +{ + PROV_MLKEM_CTX *ctx = (PROV_MLKEM_CTX *)vctx; + int ret; + + debug_print("MLKEMKEM decaps %p to %p\n", ctx, out); + debug_print("MLKEMKEM decaps inlen at %ld\n", inlen); + if (outlen != NULL) + *outlen = OSSL_MLKEM768_SHARED_SECRET_BYTES; + + if (out == NULL) { + debug_print("MLKEMKEM decaps outlen set to %ld \n", *outlen); + return 1; + } + + if (ctx->key == NULL + || ctx->key->keytype != MLKEM_KEY_TYPE_768 + || ctx->key->seckey_initialized == 0 + || in == NULL) + return 0; + + if (inlen != OSSL_MLKEM768_CIPHERTEXT_BYTES) + return 0; + + ret = ossl_mlkem768_decap((uint8_t *)out, (uint8_t *)in, inlen, &ctx->key->seckey, + ctx->key->mlkem_ctx); + + debug_print("MLKEMKEM decaps returns %d\n", ret); + return ret; +} + +const OSSL_DISPATCH ossl_mlkem768_asym_kem_functions[] = { + { OSSL_FUNC_KEM_NEWCTX, (void (*)(void))mlkem_newctx }, + { OSSL_FUNC_KEM_ENCAPSULATE_INIT, + (void (*)(void))mlkem_encapsulate_init }, + { OSSL_FUNC_KEM_ENCAPSULATE, (void (*)(void))mlkem_encapsulate }, + { OSSL_FUNC_KEM_DECAPSULATE_INIT, + (void (*)(void))mlkem_decapsulate_init }, + { OSSL_FUNC_KEM_DECAPSULATE, (void (*)(void))mlkem_decapsulate }, + { OSSL_FUNC_KEM_FREECTX, (void (*)(void))mlkem_freectx }, + { OSSL_FUNC_KEM_SET_CTX_PARAMS, + (void (*)(void))mlkem_set_ctx_params }, + { OSSL_FUNC_KEM_SETTABLE_CTX_PARAMS, + (void (*)(void))mlkem_settable_ctx_params }, + OSSL_DISPATCH_END +}; diff --git a/providers/implementations/keymgmt/build.info b/providers/implementations/keymgmt/build.info index edfb15f7e2d..b1ee39c4fe0 100644 --- a/providers/implementations/keymgmt/build.info +++ b/providers/implementations/keymgmt/build.info @@ -10,6 +10,7 @@ $MAC_GOAL=../../libdefault.a ../../libfips.a $RSA_GOAL=../../libdefault.a ../../libfips.a $TEMPLATE_GOAL=../../libtemplate.a $ML_DSA_GOAL=../../libdefault.a ../../libfips.a +$MLKEM_GOAL=../../libdefault.a IF[{- !$disabled{dh} -}] SOURCE[$DH_GOAL]=dh_kmgmt.c @@ -49,3 +50,4 @@ SOURCE[$TEMPLATE_GOAL]=template_kmgmt.c IF[{- !$disabled{'ml-dsa'} -}] SOURCE[$ML_DSA_GOAL]=ml_dsa_kmgmt.c ENDIF +SOURCE[$MLKEM_GOAL]=mlkem_kmgmt.c diff --git a/providers/implementations/keymgmt/mlkem_kmgmt.c b/providers/implementations/keymgmt/mlkem_kmgmt.c new file mode 100644 index 00000000000..02d7e5fdf09 --- /dev/null +++ b/providers/implementations/keymgmt/mlkem_kmgmt.c @@ -0,0 +1,590 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +#include +#include +#include +#include +#include +#include +#include +#include "internal/param_build_set.h" +#include +#include "prov/mlkem.h" +#include "prov/implementations.h" +#include "prov/providercommon.h" +#include "prov/provider_ctx.h" +#include "prov/securitycheck.h" +#include + +#define BUFSIZE 1000 +#if defined(NDEBUG) || defined(OPENSSL_NO_STDIO) +/* TODO(ML-KEM) to remove or replace with TRACE */ +static void debug_print(char *fmt, ...) +{ +} +#else +static void debug_print(char *fmt, ...) +{ + char out[BUFSIZE]; + va_list argptr; + + va_start(argptr, fmt); + vsnprintf(out, BUFSIZE, fmt, argptr); + va_end(argptr); + if (getenv("TEMPLATEKM")) + fprintf(stderr, "TEMPLATE_KM: %s", out); +} +#endif + +static void print_hex(const uint8_t *data, int len, const char *msg) +{ +#ifndef NDEBUG + if (msg) + printf("%s: \n", msg); + BIO_dump_fp(stdout, data, len); + printf("\n\n"); +#endif +} + +static OSSL_FUNC_keymgmt_new_fn mlkem_new; +static OSSL_FUNC_keymgmt_free_fn mlkem_free; +static OSSL_FUNC_keymgmt_gen_init_fn mlkem_gen_init; +static OSSL_FUNC_keymgmt_gen_fn mlkem_gen; +static OSSL_FUNC_keymgmt_gen_cleanup_fn mlkem_gen_cleanup; +static OSSL_FUNC_keymgmt_gen_set_params_fn mlkem_gen_set_params; +static OSSL_FUNC_keymgmt_gen_settable_params_fn mlkem_gen_settable_params; +static OSSL_FUNC_keymgmt_get_params_fn mlkem_get_params; +static OSSL_FUNC_keymgmt_gettable_params_fn mlkem_gettable_params; +static OSSL_FUNC_keymgmt_set_params_fn mlkem_set_params; +static OSSL_FUNC_keymgmt_settable_params_fn mlkem_settable_params; +static OSSL_FUNC_keymgmt_has_fn mlkem_has; +static OSSL_FUNC_keymgmt_match_fn mlkem_match; + +/* implement only when encode/decode logic becomes required/standardized: */ +#ifdef UNDEF +static OSSL_FUNC_keymgmt_import_fn mlkem_import; +static OSSL_FUNC_keymgmt_export_fn mlkem_export; +static OSSL_FUNC_keymgmt_import_types_fn mlkem_imexport_types; +static OSSL_FUNC_keymgmt_export_types_fn mlkem_imexport_types; +static OSSL_FUNC_keymgmt_dup_fn mlkem_dup; +#endif /* UNDEF */ + +struct mlkem_gen_ctx { + void *provctx; + int selection; +}; + +static void *mlkem_new(void *provctx) +{ + MLKEM768_KEY *key = NULL; + + debug_print("MLKEMKM new key req\n"); + if (!ossl_prov_is_running()) + return 0; + + key = OPENSSL_zalloc(sizeof(MLKEM768_KEY)); + if (key != NULL) { + key->keytype = MLKEM_KEY_TYPE_768; /* TODO(ML-KEM) any type */ + key->provctx = provctx; + /* + * ideally, this is a one-time allocation and ctx that should be within the + * provider context: OK to move it there to improve performance?? It would be + * the first algorithmspecific context stored: Feels weird (TODO(ML-KEM)). + */ + key->mlkem_ctx = ossl_mlkem_newctx(provctx == NULL ? NULL : PROV_LIBCTX_OF(provctx), NULL); + if (key->mlkem_ctx == NULL) { + OPENSSL_free(key); + key = NULL; + } + } + + debug_print("MLKEMKM new key = %p\n", key); + return key; +} + +static void mlkem_free(void *vkey) +{ + MLKEM768_KEY *mkey = (MLKEM768_KEY *)vkey; + + debug_print("MLKEMKM free key %p\n", mkey); + if (mkey == NULL) + return; + ossl_mlkem_ctx_free(mkey->mlkem_ctx); + OPENSSL_free(mkey->encoded_pubkey); + OPENSSL_free(mkey); +} + +static int mlkem_has(const void *keydata, int selection) +{ + const MLKEM768_KEY *key = keydata; + int ok = 0; + + debug_print("MLKEMKM has %p\n", key); + if (ossl_prov_is_running() && key != NULL) { + /* + * ML-KEM keys always have all the parameters they need (i.e. none). + * Therefore we always return with 1, if asked about parameters. + */ + ok = 1; + + if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0) + ok = ok && key->pubkey_initialized == 1; + + if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) + ok = ok && key->seckey_initialized == 1; + } + debug_print("MLKEMKM has result %d\n", ok); + return ok; +} + +static int mlkem_match(const void *keydata1, const void *keydata2, int selection) +{ + const MLKEM768_KEY *key1 = keydata1; + const MLKEM768_KEY *key2 = keydata2; + int ok = 1; + + debug_print("MLKEMKM matching %p and %p\n", key1, key2); + if (!ossl_prov_is_running()) + return 0; + + if ((selection & OSSL_KEYMGMT_SELECT_DOMAIN_PARAMETERS) != 0) + ok = ok && key1->keytype == key2->keytype; + + /* TODO(ML-KEM) */ + debug_print("MLKEMKM matching for now NOT YET IMPLEMENTED\n"); + +/* TODO(ML-KEM) template code to be completed as and when needed: */ +#ifdef UNDEF + if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) { + int key_checked = 0; + + if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0) { + const uint8_t *pa = key1->pubkey; + const uint8_t *pb = key2->pubkey; + + if (pa != NULL && pb != NULL) { + ok = ok + && key1->keytype == key2->keytype + && CRYPTO_memcmp(pa, pb, MLKEM768_PUBLICKEYBYTES) == 0; + key_checked = 1; + } + } + if (!key_checked + && (selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) { + const uint8_t *pa = key1->seckey; + const uint8_t *pb = key2->seckey; + + if (pa != NULL && pb != NULL) { + ok = ok + && key1->keytype == key2->keytype + && CRYPTO_memcmp(pa, pb, MLKEM768_SECRETKEYBYTES) == 0; + key_checked = 1; + } + } + ok = ok && key_checked; + } +#endif /* UNDEF */ + debug_print("MLKEMKM match result %d\n", ok); + return ok; +} + +/* TODO(ML-KEM) as and when encode/decode becomes needed/standardized */ +#ifdef UNDEF +static int key_to_params(MLKEM768_KEY *key, OSSL_PARAM_BLD *tmpl, + OSSL_PARAM params[], int include_private) +{ + if (key == NULL) + return 0; + + if (key->keytype != MLKEM_KEY_TYPE_768) + return 0; + + if (!ossl_param_build_set_octet_string(tmpl, params, + OSSL_PKEY_PARAM_PUB_KEY, + key->pubkey, MLKEM768_PUBLICKEYBYTES)) + return 0; + + if (include_private + && key->seckey != NULL + && !ossl_param_build_set_octet_string(tmpl, params, + OSSL_PKEY_PARAM_PRIV_KEY, + key->seckey, MLKEM768_SECRETKEYBYTES)) + return 0; + + return 1; +} + +static int mlkem_export(void *key, int selection, OSSL_CALLBACK *param_cb, + void *cbarg) +{ + MLKEM768_KEY *mkey = key; + OSSL_PARAM_BLD *tmpl; + OSSL_PARAM *params = NULL; + int ret = 0; + + debug_print("MLKEMKM export %p\n", key); + if (!ossl_prov_is_running() || key == NULL) + return 0; + + if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0) + return 0; + + tmpl = OSSL_PARAM_BLD_new(); + if (tmpl == NULL) + return 0; + + if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) { + int include_private = ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0); + + if (!key_to_params(mkey, tmpl, NULL, include_private)) + goto err; + } + + params = OSSL_PARAM_BLD_to_param(tmpl); + if (params == NULL) + goto err; + + ret = param_cb(params, cbarg); + OSSL_PARAM_free(params); +err: + OSSL_PARAM_BLD_free(tmpl); + debug_print("MLKEMKM export result %d\n", ret); + return ret; +} + +static int ossl_mlkem_key_fromdata(MLKEM768_KEY *key, + const OSSL_PARAM params[], + int include_private) +{ + size_t privkeylen = 0, pubkeylen = 0; + const OSSL_PARAM *param_priv_key = NULL, *param_pub_key; + unsigned char *pubkey; + + if (key == NULL) + return 0; + + if (key->keytype != MLKEM_KEY_TYPE_768) + return 0; + + param_pub_key = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY); + if (include_private) + param_priv_key = + OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PRIV_KEY); + + if (param_pub_key == NULL && param_priv_key == NULL) + return 0; + + if (param_priv_key != NULL) { + if (!OSSL_PARAM_get_octet_string(param_priv_key, + (void **)&key->seckey, + MLKEM768_SECRETKEYBYTES, + &privkeylen)) + return 0; + if (privkeylen != MLKEM768_SECRETKEYBYTES) { + debug_print("sec key len mismatch in import: %ld vs %d: HOWCAN?\n", + privkeylen, MLKEM768_SECRETKEYBYTES); + OPENSSL_secure_clear_free(key->seckey, privkeylen); + key->seckey = NULL; + return 0; + } + } + + pubkey = key->pubkey; + if (param_pub_key != NULL + && !OSSL_PARAM_get_octet_string(param_pub_key, + (void **)&pubkey, + ossl_mlkem768_PUBLIC_KEY_BYTES, + &pubkeylen)) + return 0; + + if ((param_pub_key != NULL && pubkeylen != ossl_mlkem768_PUBLIC_KEY_BYTES)) { + debug_print("sec key len mismatch in import: %ld vs %d: HOWCAN?\n", + pubkeylen, ossl_mlkem768_PUBLIC_KEY_BYTES); + return 0; + } + + /* + * TBD if hybrid logic is not getting cleanly implemented in separate logic: + * reconstitute (only) classic part here + */ + + return 1; +} + +static int mlkem_import(void *key, int selection, const OSSL_PARAM params[]) +{ + MLKEM768_KEY *mkey = key; + int ok = 1; + int include_private; + + debug_print("MLKEMKM import %p\n", mkey); + if (!ossl_prov_is_running() || key == NULL) + return 0; + + if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0) + return 0; + + include_private = selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY ? 1 : 0; + ok = ok && ossl_mlkem_key_fromdata(mkey, params, include_private); + + debug_print("MLKEMKM import result %d\n", ok); + return ok; +} + +# define MLKEM768_KEY_TYPES() \ + OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PUB_KEY, NULL, 0), \ + OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PRIV_KEY, NULL, 0) + +static const OSSL_PARAM mlkem_key_types[] = { + MLKEM768_KEY_TYPES(), + OSSL_PARAM_END +}; + +static const OSSL_PARAM *mlkem_imexport_types(int selection) +{ + debug_print("MLKEMKM getting imexport types\n"); + if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0) + return mlkem_key_types; + return NULL; +} + +#endif /* UNDEF */ + +static int mlkem_get_params(void *key, OSSL_PARAM params[]) +{ + MLKEM768_KEY *mkey = key; + OSSL_PARAM *p; + + debug_print("MLKEMKM get params %p\n", mkey); + if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_BITS)) != NULL + && !OSSL_PARAM_set_int(p, sizeof(ossl_mlkem768_private_key) * 8)) + return 0; + if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_SECURITY_BITS)) != NULL + && !OSSL_PARAM_set_int(p, OSSL_MLKEM768_SECURITY_BITS)) + return 0; + if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_MAX_SIZE)) != NULL + && !OSSL_PARAM_set_int(p, OSSL_MLKEM768_CIPHERTEXT_BYTES)) + return 0; + if ((p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY)) != NULL + && mkey->pubkey_initialized == 1) { + if (!OSSL_PARAM_set_octet_string(p, mkey->encoded_pubkey, OSSL_MLKEM768_PUBLIC_KEY_BYTES)) + return 0; + debug_print("MLKEMKM got encoded public key of len %d\n", OSSL_MLKEM768_PUBLIC_KEY_BYTES); + print_hex(mkey->encoded_pubkey, OSSL_MLKEM768_PUBLIC_KEY_BYTES, "enc PK"); + } + + debug_print("MLKEMKM get params OK\n"); + return 1; +} + +static const OSSL_PARAM mlkem_gettable_params_arr[] = { + OSSL_PARAM_int(OSSL_PKEY_PARAM_BITS, NULL), + OSSL_PARAM_int(OSSL_PKEY_PARAM_SECURITY_BITS, NULL), + OSSL_PARAM_int(OSSL_PKEY_PARAM_MAX_SIZE, NULL), + OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0), + OSSL_PARAM_END +}; + +static const OSSL_PARAM *mlkem_gettable_params(void *provctx) +{ + debug_print("MLKEMKM gettable params called\n"); + return mlkem_gettable_params_arr; +} + +static int mlkem_set_params(void *key, const OSSL_PARAM params[]) +{ + MLKEM768_KEY *mkey = key; + const OSSL_PARAM *p; + + debug_print("MLKEMKM set params called for %p\n", mkey); + if (params == NULL) + return 1; + + p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY); + if (p != NULL) { + size_t len_stored; + + if (p->data_size != OSSL_MLKEM768_PUBLIC_KEY_BYTES + || !OSSL_PARAM_get_octet_string(p, (void **)&mkey->encoded_pubkey, + OSSL_MLKEM768_PUBLIC_KEY_BYTES, + &len_stored)) + return 0; + debug_print("encoded pub key successfully stored with %ld bytes\n", len_stored); + if (!ossl_mlkem768_recreate_public_key(mkey->encoded_pubkey, &mkey->pubkey, + mkey->mlkem_ctx)) + return 0; + mkey->pubkey_initialized = 1; + } + + debug_print("MLKEMKM set params OK\n"); + return 1; +} + +static const OSSL_PARAM mlkem_settable_params_arr[] = { + OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0), + OSSL_PARAM_END +}; + +static const OSSL_PARAM *mlkem_settable_params(void *provctx) +{ + debug_print("MLKEMKM settable params called\n"); + return mlkem_settable_params_arr; +} + +static void *mlkem_gen_init(void *provctx, int selection, + const OSSL_PARAM params[]) +{ + struct mlkem_gen_ctx *gctx = NULL; + + debug_print("MLKEMKM gen init called for %p\n", provctx); + if (!ossl_prov_is_running()) + return NULL; + + if ((gctx = OPENSSL_zalloc(sizeof(*gctx))) != NULL) { + gctx->provctx = provctx; + gctx->selection = selection; + } + if (!mlkem_gen_set_params(gctx, params)) { + OPENSSL_free(gctx); + gctx = NULL; + } + debug_print("MLKEMKM gen init returns %p\n", gctx); + return gctx; +} + +static int mlkem_gen_set_params(void *genctx, const OSSL_PARAM params[]) +{ + struct mlkem_gen_ctx *gctx = genctx; + + if (gctx == NULL) + return 0; + + debug_print("MLKEMKM empty gen_set params called for %p\n", gctx); + return 1; +} + +static const OSSL_PARAM *mlkem_gen_settable_params(ossl_unused void *genctx, + ossl_unused void *provctx) +{ + static OSSL_PARAM settable[] = { + OSSL_PARAM_END + }; + return settable; +} + +static void *mlkem_gen(void *vctx, OSSL_CALLBACK *osslcb, void *cbarg) +{ + struct mlkem_gen_ctx *gctx = (struct mlkem_gen_ctx *)vctx; + MLKEM768_KEY *mkey; + + debug_print("MLKEMKM gen called for %p\n", gctx); + if (gctx == NULL) + return NULL; + + if ((mkey = mlkem_new(gctx->provctx)) == NULL) { + ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR); + return NULL; + } + + /* If we're doing parameter generation then we just return a blank key */ + if ((gctx->selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0) { + debug_print("MLKEMKM gen returns blank %p\n", mkey); + return mkey; + } + + mkey->keytype = MLKEM_KEY_TYPE_768; + + if (mkey->encoded_pubkey == NULL) { + mkey->encoded_pubkey = OPENSSL_malloc(OSSL_MLKEM768_PUBLIC_KEY_BYTES); + if (mkey->encoded_pubkey == NULL) + goto err; + } + + if (!ossl_mlkem768_generate_key(mkey->encoded_pubkey, NULL, &mkey->seckey, + mkey->mlkem_ctx) + || !ossl_mlkem768_public_from_private(&mkey->pubkey, &mkey->seckey)) + goto err; + + mkey->seckey_initialized = 1; + mkey->pubkey_initialized = 1; + + debug_print("MLKEMKM gen returns set %p\n", mkey); + return mkey; + +err: + OPENSSL_free(mkey); + return NULL; +} + +static void mlkem_gen_cleanup(void *genctx) +{ + struct mlkem_gen_ctx *gctx = genctx; + + debug_print("MLKEMKM gen cleanup for %p\n", gctx); + OPENSSL_free(gctx); +} + +static void *mlkem_dup(const void *vsrckey, int selection) +{ + const MLKEM768_KEY *srckey = (const MLKEM768_KEY *)vsrckey; + MLKEM768_KEY *dstkey; + + debug_print("MLKEMKM dup called for %p\n", srckey); + if (!ossl_prov_is_running()) + return NULL; + + dstkey = mlkem_new(srckey->provctx); + if (dstkey == NULL) + return NULL; + + dstkey->keytype = srckey->keytype; + if (srckey->pubkey_initialized == 1 + && (selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0) { + memcpy((void *)&dstkey->pubkey, (void *)&srckey->pubkey, sizeof(srckey->pubkey)); + dstkey->encoded_pubkey = OPENSSL_malloc(OSSL_MLKEM768_PUBLIC_KEY_BYTES); + if (srckey->encoded_pubkey != NULL) + memcpy(dstkey->encoded_pubkey, srckey->encoded_pubkey, OSSL_MLKEM768_PUBLIC_KEY_BYTES); + dstkey->pubkey_initialized = 1; + } + if (srckey->seckey_initialized == 1 + && (selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) { + memcpy((void *)&dstkey->seckey, (void *)&srckey->seckey, sizeof(srckey->seckey)); + dstkey->seckey_initialized = 1; + } + + debug_print("MLKEMKM dup returns %p\n", dstkey); + return dstkey; +} + +const OSSL_DISPATCH ossl_mlkem768_keymgmt_functions[] = { + { OSSL_FUNC_KEYMGMT_NEW, (void (*)(void))mlkem_new }, + { OSSL_FUNC_KEYMGMT_FREE, (void (*)(void))mlkem_free }, + { OSSL_FUNC_KEYMGMT_GET_PARAMS, (void (*) (void))mlkem_get_params }, + { OSSL_FUNC_KEYMGMT_GETTABLE_PARAMS, (void (*) (void))mlkem_gettable_params }, + { OSSL_FUNC_KEYMGMT_SET_PARAMS, (void (*) (void))mlkem_set_params }, + { OSSL_FUNC_KEYMGMT_SETTABLE_PARAMS, (void (*) (void))mlkem_settable_params }, + { OSSL_FUNC_KEYMGMT_HAS, (void (*)(void))mlkem_has }, + { OSSL_FUNC_KEYMGMT_MATCH, (void (*)(void))mlkem_match }, + { OSSL_FUNC_KEYMGMT_GEN_INIT, (void (*)(void))mlkem_gen_init }, + { OSSL_FUNC_KEYMGMT_GEN_SET_PARAMS, (void (*)(void))mlkem_gen_set_params }, + { OSSL_FUNC_KEYMGMT_GEN_SETTABLE_PARAMS, + (void (*)(void))mlkem_gen_settable_params }, + { OSSL_FUNC_KEYMGMT_GEN, (void (*)(void))mlkem_gen }, + { OSSL_FUNC_KEYMGMT_GEN_CLEANUP, (void (*)(void))mlkem_gen_cleanup }, + { OSSL_FUNC_KEYMGMT_DUP, (void (*)(void))mlkem_dup }, + /* + * TODO(ML-KEM) don't do for now, see https://github.com/openssl/private/issues/698 + * { OSSL_FUNC_KEYMGMT_IMPORT_TYPES, (void (*)(void))mlkem_imexport_types }, + * { OSSL_FUNC_KEYMGMT_EXPORT_TYPES, (void (*)(void))mlkem_imexport_types }, + * { OSSL_FUNC_KEYMGMT_IMPORT, (void (*)(void))mlkem_export }, + * { OSSL_FUNC_KEYMGMT_EXPORT, (void (*)(void))mlkem_import }, + */ + OSSL_DISPATCH_END +}; diff --git a/test/build.info b/test/build.info index 5e511590d34..ad3831d5ed2 100644 --- a/test/build.info +++ b/test/build.info @@ -1011,6 +1011,13 @@ IF[{- !$disabled{tests} -}] INCLUDE[asn1_dsa_internal_test]=.. ../include ../apps/include DEPEND[asn1_dsa_internal_test]=../libcrypto.a libtestutil.a + IF[{- !$disabled{'mlkem'} -}] + PROGRAMS{noinst}=mlkem_internal_test + SOURCE[mlkem_internal_test]=mlkem_internal_test.c + INCLUDE[mlkem_internal_test]=../include ../apps/include + DEPEND[mlkem_internal_test]=../libcrypto.a libtestutil.a + ENDIF + SOURCE[keymgmt_internal_test]=keymgmt_internal_test.c INCLUDE[keymgmt_internal_test]=.. ../include ../apps/include DEPEND[keymgmt_internal_test]=../libcrypto.a libtestutil.a diff --git a/test/evp_extra_test.c b/test/evp_extra_test.c index b9124d02b56..a50ba9b4382 100644 --- a/test/evp_extra_test.c +++ b/test/evp_extra_test.c @@ -5912,6 +5912,109 @@ static int test_invalid_ctx_for_digest(void) return ret; } +static int test_ml_kem(void) +{ + EVP_PKEY *akey, *bkey = NULL; + int res = 0; + size_t publen; + unsigned char *rawpub = NULL; + EVP_PKEY_CTX *ctx = NULL; + unsigned char *wrpkey = NULL, *agenkey = NULL, *bgenkey = NULL; + size_t wrpkeylen, agenkeylen, bgenkeylen, i; + + /* Generate Alice's key */ + akey = EVP_PKEY_Q_keygen(testctx, NULL, "ML-KEM-768"); + if (!TEST_ptr(akey)) + goto err; + + /* Get the raw public key */ + publen = EVP_PKEY_get1_encoded_public_key(akey, &rawpub); + if (!TEST_size_t_gt(publen, 0)) + goto err; + + /* Create Bob's key and populate it with Alice's public key data */ + bkey = EVP_PKEY_new(); + if (!TEST_ptr(bkey)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_copy_parameters(bkey, akey), 0)) + goto err; + + if (!TEST_true(EVP_PKEY_set1_encoded_public_key(bkey, rawpub, publen))) + goto err; + + /* Encapsulate Bob's key */ + ctx = EVP_PKEY_CTX_new_from_pkey(testctx, bkey, NULL); + if (!TEST_ptr(ctx)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_encapsulate_init(ctx, NULL), 0)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_encapsulate(ctx, NULL, &wrpkeylen, NULL, + &bgenkeylen), 0)) + goto err; + + if (!TEST_size_t_gt(wrpkeylen, 0) || !TEST_size_t_gt(bgenkeylen, 0)) + goto err; + + wrpkey = OPENSSL_zalloc(wrpkeylen); + bgenkey = OPENSSL_zalloc(bgenkeylen); + if (!TEST_ptr(wrpkey) || !TEST_ptr(bgenkey)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_encapsulate(ctx, wrpkey, &wrpkeylen, bgenkey, + &bgenkeylen), 0)) + goto err; + + EVP_PKEY_CTX_free(ctx); + + /* Alice now decapsulates Bob's key */ + ctx = EVP_PKEY_CTX_new_from_pkey(testctx, akey, NULL); + if (!TEST_ptr(ctx)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_decapsulate_init(ctx, NULL), 0)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_decapsulate(ctx, NULL, &agenkeylen, wrpkey, + wrpkeylen), 0)) + goto err; + + if (!TEST_size_t_gt(agenkeylen, 0)) + goto err; + + agenkey = OPENSSL_zalloc(agenkeylen); + if (!TEST_ptr(agenkey)) + goto err; + + if (!TEST_int_gt(EVP_PKEY_decapsulate(ctx, agenkey, &agenkeylen, wrpkey, + wrpkeylen), 0)) + goto err; + + /* Hopefully we ended up with a shared key */ + if (!TEST_mem_eq(agenkey, agenkeylen, bgenkey, bgenkeylen)) + goto err; + + /* Verify we generated a non-zero shared key */ + for (i = 0; i < agenkeylen; i++) + if (agenkey[i] != 0) + break; + if (!TEST_size_t_ne(i, agenkeylen)) + return 0; + + res = 1; + err: + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(akey); + EVP_PKEY_free(bkey); + OPENSSL_free(rawpub); + OPENSSL_free(wrpkey); + OPENSSL_free(agenkey); + OPENSSL_free(bgenkey); + return res; +} + static int test_evp_cipher_pipeline(void) { OSSL_PROVIDER *fake_pipeline = NULL; @@ -6310,6 +6413,7 @@ int setup_tests(void) #endif ADD_TEST(test_invalid_ctx_for_digest); + ADD_TEST(test_ml_kem); ADD_TEST(test_evp_cipher_pipeline); diff --git a/test/mlkem_internal_test.c b/test/mlkem_internal_test.c new file mode 100644 index 00000000000..9dbd795748e --- /dev/null +++ b/test/mlkem_internal_test.c @@ -0,0 +1,90 @@ +/* + * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved. + * + * Licensed under the Apache License 2.0 (the "License"). You may not use + * this file except in compliance with the License. You can obtain a copy + * in the file LICENSE in the source distribution or at + * https://www.openssl.org/source/license.html + */ + +#include +#ifndef OPENSSL_NO_STDIO +# include +#endif + +#include + +#include +#include "testutil.h" +#include "testutil/output.h" + +int main(void) +{ + uint8_t out_encoded_public_key[OSSL_MLKEM768_PUBLIC_KEY_BYTES]; + uint8_t out_ciphertext[OSSL_MLKEM768_CIPHERTEXT_BYTES]; + uint8_t out_shared_secret[OSSL_MLKEM768_SHARED_SECRET_BYTES]; + uint8_t out_shared_secret2[OSSL_MLKEM768_SHARED_SECRET_BYTES]; + ossl_mlkem768_private_key private_key; + ossl_mlkem768_public_key public_key; + ossl_mlkem768_public_key recreated_public_key; + uint8_t *p1, *p2; + ossl_mlkem_ctx *mlkem_ctx = ossl_mlkem_newctx(NULL, NULL); + int ret = 1; + + /* enable TEST_* API */ + test_open_streams(); + + /* first, generate a key pair */ + if (!ossl_mlkem768_generate_key(out_encoded_public_key, NULL, + &private_key, mlkem_ctx)) { + ret = -1; + goto end; + } + /* public key component to be created from private key */ + if (!ossl_mlkem768_public_from_private(&public_key, &private_key)) { + ret = -2; + goto end; + } + /* try to re-create public key structure from encoded public key */ + if (!ossl_mlkem768_recreate_public_key(out_encoded_public_key, + &recreated_public_key, mlkem_ctx)) { + ret = -3; + goto end; + } + /* validate identity of both public key structures */ + p1 = (uint8_t *)&public_key; + p2 = (uint8_t *)&recreated_public_key; + if (!TEST_int_eq(memcmp(p1, p2, sizeof(public_key)), 0)) { + ret = -4; + goto end; + } + /* encaps - decaps test: validate shared secret identity */ + if (!ossl_mlkem768_encap(out_ciphertext, out_shared_secret, + &recreated_public_key, mlkem_ctx)) { + ret = -5; + goto end; + } + if (!ossl_mlkem768_decap(out_shared_secret2, out_ciphertext, + OSSL_MLKEM768_CIPHERTEXT_BYTES, &private_key, mlkem_ctx)) { + ret = -6; + goto end; + } + if (!TEST_int_eq(memcmp(out_shared_secret, out_shared_secret2, + OSSL_MLKEM768_SHARED_SECRET_BYTES), 0)) { + ret = -7; + goto end; + } + /* so far so good, now a quick negative test by breaking the ciphertext */ + out_ciphertext[0]++; + if (!ossl_mlkem768_decap(out_shared_secret2, out_ciphertext, + OSSL_MLKEM768_CIPHERTEXT_BYTES, &private_key, mlkem_ctx)) + goto end; + /* If decap passed, ensure we at least have a mismatch */ + if (!TEST_int_ne(memcmp(out_shared_secret, out_shared_secret2, + OSSL_MLKEM768_SHARED_SECRET_BYTES), 0)) + ret = -8; + +end: + ossl_mlkem_ctx_free(mlkem_ctx); + return ret; +}