#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> n) & 1)
-#define DEGREE ML_KEM_DEGREE
-#define BARRETT_SHIFT (2 * ML_KEM_LOG2PRIME)
+/*
+ * 12 bits are sufficient to losslessly represent values in [0, q-1].
+ * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
+ */
+#define DEGREE ML_KEM_DEGREE
+#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
+#define LOG2PRIME 12
+#define BARRETT_SHIFT (2 * LOG2PRIME)
+
#ifdef SHA3_BLOCKSIZE
-# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
+# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
/*
uint16_t c[ML_KEM_DEGREE];
} scalar;
-/* General form of public and private key storage */
+/* Key material allocation layout */
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
- struct ossl_ml_kem_##name##_st { \
+ struct name##_alloc { \
/* Public vector |t| */ \
scalar tbuf[(rank)]; \
- /* Pre-computed matrix |m| */ \
+ /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
scalar mbuf[(rank)*(rank)] \
/* optional private key data */ \
private_sz \
/* Declare variant-specific public and private storage */
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
- DECLARE_ML_KEM_KEYDATA(bits##_puballoc, ML_KEM_##bits##_RANK,;); \
- DECLARE_ML_KEM_KEYDATA(bits##_prvalloc, ML_KEM_##bits##_RANK,;\
+ DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
+ DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
scalar sbuf[ML_KEM_##bits##_RANK]; \
- uint8_t zbuf[ML_KEM_RANDOM_BYTES];)
+ uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
typedef __owur
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
-
-/* The wire form of a losslessly encoded vector (12-bits per element). */
-#define ML_KEM_VECTOR_BYTES(rank) \
- ((3 * ML_KEM_DEGREE / 2) * (rank))
-
-/*
- * The expanded internal form stores each coefficient as a 16-bit unsigned int
- */
-#define ML_KEM_VECALLOC_BYTES(rank) \
- (2 * ML_KEM_DEGREE * (rank))
-
-/*
- * The wire-form public key consists of the lossless encoding of the vector
- * "t" = "A" * "s" + "e", followed by public seed "rho".
- */
-#define ML_KEM_PUBKEY_BYTES(rank) \
- (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_RANDOM_BYTES)
+static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s);
/*
- * Our internal serialised private key concatenates serialisations of "s", the
- * public key, the public key hash, and the failure secret "z".
+ * The wire-form of a losslessly encoded vector uses 12-bits per element.
+ *
+ * The wire-form public key consists of the lossless encoding of the public
+ * vector |t|, followed by the public seed |rho|.
+ *
+ * Our serialised private key concatenates serialisations of the private vector
+ * |s|, the public key, the public key hash, and the failure secret |z|.
*/
-#define ML_KEM_PRVKEY_BYTES(rank) \
- (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_PUBKEY_BYTES(rank) \
- + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)
+#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
+#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
+#define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
/*
* Encapsulation produces a vector "u" and a scalar "v", whose coordinates
* "dv" bits, respectively. This encoding is the ciphertext input for
* decapsulation.
*/
-#define ML_KEM_U_VECTOR_BYTES(rank, du) \
- ((ML_KEM_DEGREE / 8) * (du) * (rank))
-#define ML_KEM_V_SCALAR_BYTES(dv) \
- ((ML_KEM_DEGREE / 8) * (dv))
-#define ML_KEM_CTEXT_BYTES(rank, du, dv) \
- (ML_KEM_U_VECTOR_BYTES(rank, du) + ML_KEM_V_SCALAR_BYTES(dv))
-
-/*
- * Variant-specific sizes
- */
-#define ML_KEM_512_VECTOR_BYTES \
- ML_KEM_VECTOR_BYTES(ML_KEM_512_RANK)
-#define ML_KEM_768_VECTOR_BYTES \
- ML_KEM_VECTOR_BYTES(ML_KEM_768_RANK)
-#define ML_KEM_1024_VECTOR_BYTES \
- ML_KEM_VECTOR_BYTES(ML_KEM_1024_RANK)
-
-#define ML_KEM_512_PUBLIC_KEY_BYTES \
- ML_KEM_PUBKEY_BYTES(ML_KEM_512_RANK)
-#define ML_KEM_768_PUBLIC_KEY_BYTES \
- ML_KEM_PUBKEY_BYTES(ML_KEM_768_RANK)
-#define ML_KEM_1024_PUBLIC_KEY_BYTES \
- ML_KEM_PUBKEY_BYTES(ML_KEM_1024_RANK)
-
-#define ML_KEM_512_PRIVATE_KEY_BYTES \
- ML_KEM_PRVKEY_BYTES(ML_KEM_512_RANK)
-#define ML_KEM_768_PRIVATE_KEY_BYTES \
- ML_KEM_PRVKEY_BYTES(ML_KEM_768_RANK)
-#define ML_KEM_1024_PRIVATE_KEY_BYTES \
- ML_KEM_PRVKEY_BYTES(ML_KEM_1024_RANK)
-
-#define ML_KEM_512_U_VECTOR_BYTES \
- ML_KEM_U_VECTOR_BYTES(ML_KEM_512_RANK, ML_KEM_512_DU)
-#define ML_KEM_768_U_VECTOR_BYTES \
- ML_KEM_U_VECTOR_BYTES(ML_KEM_768_RANK, ML_KEM_768_DU)
-#define ML_KEM_1024_U_VECTOR_BYTES \
- ML_KEM_U_VECTOR_BYTES(ML_KEM_1024_RANK, ML_KEM_1024_DU)
-
-#define ML_KEM_512_V_SCALAR_BYTES \
- ML_KEM_V_SCALAR_BYTES(ML_KEM_512_DV)
-#define ML_KEM_768_V_SCALAR_BYTES \
- ML_KEM_V_SCALAR_BYTES(ML_KEM_768_DV)
-#define ML_KEM_1024_V_SCALAR_BYTES \
- ML_KEM_V_SCALAR_BYTES(ML_KEM_1024_DV)
-
-#define ML_KEM_512_CIPHERTEXT_BYTES \
- (ML_KEM_512_U_VECTOR_BYTES + ML_KEM_512_V_SCALAR_BYTES)
-#define ML_KEM_768_CIPHERTEXT_BYTES \
- (ML_KEM_768_U_VECTOR_BYTES + ML_KEM_768_V_SCALAR_BYTES)
-#define ML_KEM_1024_CIPHERTEXT_BYTES \
- (ML_KEM_1024_U_VECTOR_BYTES + ML_KEM_1024_V_SCALAR_BYTES)
+#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
+#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
+#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
/*
* Per-variant fixed parameters
static const ML_KEM_VINFO vinfo_map[3] = {
{
"ML-KEM-512",
- ML_KEM_512_VECTOR_BYTES,
- ML_KEM_512_PRIVATE_KEY_BYTES,
- ML_KEM_512_PUBLIC_KEY_BYTES,
- ML_KEM_512_CIPHERTEXT_BYTES,
- ML_KEM_512_U_VECTOR_BYTES,
- sizeof(struct ossl_ml_kem_512_puballoc_st),
- sizeof(struct ossl_ml_kem_512_prvalloc_st),
- ML_KEM_512,
- 512,
+ PRVKEY_BYTES(512),
+ sizeof(struct prvkey_512_alloc),
+ PUBKEY_BYTES(512),
+ sizeof(struct pubkey_512_alloc),
+ CTEXT_BYTES(512),
+ VECTOR_BYTES(512),
+ U_VECTOR_BYTES(512),
+ ML_KEM_512_VARIANT,
+ ML_KEM_512_BITS,
ML_KEM_512_RANK,
ML_KEM_512_DU,
ML_KEM_512_DV,
- ML_KEM_512_RNGSEC
+ ML_KEM_512_SECBITS
},
{
"ML-KEM-768",
- ML_KEM_768_VECTOR_BYTES,
- ML_KEM_768_PRIVATE_KEY_BYTES,
- ML_KEM_768_PUBLIC_KEY_BYTES,
- ML_KEM_768_CIPHERTEXT_BYTES,
- ML_KEM_768_U_VECTOR_BYTES,
- sizeof(struct ossl_ml_kem_768_puballoc_st),
- sizeof(struct ossl_ml_kem_768_prvalloc_st),
- ML_KEM_768,
- 768,
+ PRVKEY_BYTES(768),
+ sizeof(struct prvkey_768_alloc),
+ PUBKEY_BYTES(768),
+ sizeof(struct pubkey_768_alloc),
+ CTEXT_BYTES(768),
+ VECTOR_BYTES(768),
+ U_VECTOR_BYTES(768),
+ ML_KEM_768_VARIANT,
+ ML_KEM_768_BITS,
ML_KEM_768_RANK,
ML_KEM_768_DU,
ML_KEM_768_DV,
- ML_KEM_768_RNGSEC
+ ML_KEM_768_SECBITS
},
{
"ML-KEM-1024",
- ML_KEM_1024_VECTOR_BYTES,
- ML_KEM_1024_PRIVATE_KEY_BYTES,
- ML_KEM_1024_PUBLIC_KEY_BYTES,
- ML_KEM_1024_CIPHERTEXT_BYTES,
- ML_KEM_1024_U_VECTOR_BYTES,
- sizeof(struct ossl_ml_kem_1024_puballoc_st),
- sizeof(struct ossl_ml_kem_1024_prvalloc_st),
- ML_KEM_1024,
- 1024,
+ PRVKEY_BYTES(1024),
+ sizeof(struct prvkey_1024_alloc),
+ PUBKEY_BYTES(1024),
+ sizeof(struct pubkey_1024_alloc),
+ CTEXT_BYTES(1024),
+ VECTOR_BYTES(1024),
+ U_VECTOR_BYTES(1024),
+ ML_KEM_1024_VARIANT,
+ ML_KEM_1024_BITS,
ML_KEM_1024_RANK,
ML_KEM_1024_DU,
ML_KEM_1024_DV,
- ML_KEM_1024_RNGSEC
+ ML_KEM_1024_SECBITS
}
};
static const unsigned int kBarrettShift = BARRETT_SHIFT;
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
-static const uint16_t kInverseDegree = ML_KEM_INVERSE_DEGREE;
+static const uint16_t kInverseDegree = INVERSE_DEGREE;
/*
* Python helper:
&& single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
}
+/*
+ * Incremental hash_h of expanded public key.
+ *
+ * Instead of hashing a fully serialised public key, each scalar of the
+ * public vector |t| is 12-bit encoded into a small stack buffer and fed to
+ * the key's SHA3-256 digest, followed by the public seed |rho|.  The digest
+ * is written to |pkhash|.
+ *
+ * Returns 1 on success, or 0 if any digest operation fails or the digest
+ * length is not ML_KEM_PKHASH_BYTES.
+ */
+static int
+hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
+ EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
+{
+ const ML_KEM_VINFO *vinfo = key->vinfo;
+ const scalar *t = key->t, *end = t + vinfo->rank;
+ unsigned int sz;
+
+ if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
+ return 0;
+
+ while (t < end) {
+ uint8_t buf[3 * DEGREE / 2];
+
+ scalar_encode_12(buf, t++);
+ if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
+ return 0;
+ }
+
+ if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
+ return 0;
+ return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
+ && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
+}
+
+
/*
* FIPS 203, Section 4.1, equation (4.5): G. SHA3-512 hash of a variable
* length input, producing 64 bytes of output, in particular the seeds
static void scalar_mult(scalar *out, const scalar *lhs,
const scalar *rhs)
{
- int i;
- uint32_t real_real, img_img, real_img, img_real;
-
- for (i = 0; i < DEGREE / 2; i++) {
- real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
- img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
- real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
- img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
- out->c[2 * i] =
- reduce(real_real +
- (uint32_t)reduce(img_img) * kModRoots[i]);
- out->c[2 * i + 1] = reduce(img_real + real_img);
+ uint16_t *curr = out->c, *end = curr + DEGREE;
+ const uint16_t *lc = lhs->c, *rc = rhs->c;
+ const uint16_t *roots = kModRoots;
+
+ while (curr < end) {
+ uint32_t l0 = *lc++, r0 = *rc++;
+ uint32_t l1 = *lc++, r1 = *rc++;
+ uint32_t zetapow = *roots++;
+
+ *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
+ *curr++ = reduce(l0 * r1 + l1 * r0);
}
}
+/* As scalar_mult(), but adds the product to the existing value of |out| */
static ossl_inline
void scalar_mult_add(scalar *out, const scalar *lhs,
const scalar *rhs)
{
- scalar product;
+ uint16_t *curr = out->c, *end = curr + DEGREE;
+ const uint16_t *lc = lhs->c, *rc = rhs->c;
+ const uint16_t *roots = kModRoots;
- scalar_mult(&product, lhs, rhs);
- scalar_add(out, &product);
+ while (curr < end) {
+ uint32_t l0 = *lc++, r0 = *rc++;
+ uint32_t l1 = *lc++, r1 = *rc++;
+ uint16_t *c0 = curr++;
+ uint16_t *c1 = curr++;
+ uint32_t zetapow = *roots++;
+
+ *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
+ *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
+ }
}
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
*/
static void scalar_encode_12(uint8_t out[3 * DEGREE / 2], const scalar *s)
{
- const uint16_t *c = s->c;
- int i;
+ const uint16_t *curr = s->c, *end = curr + DEGREE;
+ uint16_t c1, c2;
- for (i = 0; i < DEGREE / 2; ++i) {
- uint16_t c1 = *c++;
- uint16_t c2 = *c++;
+#define encode_two \
+ c1 = *curr++; \
+ c2 = *curr++; \
+ *out++ = (uint8_t) c1; \
+ *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4)); \
+ *out++ = (uint8_t) (c2 >> 4)
- *out++ = (uint8_t) c1;
- *out++ = (uint8_t) (((c1 >> 8) & 0x0f) | ((c2 & 0x0f) << 4));
- *out++ = (uint8_t) (c2 >> 4);
+ while (curr < end) {
+ encode_two;
+ encode_two;
+ encode_two;
+ encode_two;
}
+#undef encode_two
}
/*
}
/*
- * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns one on
- * success or zero if any parsed value is >= |ML_KEM_PRIME|.
+ * Decodes 32*|rank|*|bits| bytes from |in| into |out|.  Each scalar is
+ * decompressed and transformed via the NTT as soon as it is decoded.  The
+ * function returns early, leaving the remaining scalars unprocessed, if any
+ * parsed value is >= |ML_KEM_PRIME|.
*
* Note: Used only in decrypt_cpa(), which returns void and so does not check
* the return value of this function. Side-channels are fine when the input
* ciphertext to decap() is simply syntactically invalid.
*/
-static void vector_decode(scalar *out, const uint8_t *in, int bits, int rank)
+static void
+vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
{
int stride = bits * DEGREE / 8;
- for (; rank-- > 0; in += stride)
- if (!scalar_decode(out++, in, bits))
+ for (; rank-- > 0; in += stride, ++out) {
+ if (!scalar_decode(out, in, bits))
return;
+ scalar_decompress(out, bits);
+ scalar_ntt(out);
+ }
}
/* vector_encode(), specialised to bits == 12. */
scalar_compress(a++, bits);
}
-/* In-place decompression of each scalar component */
-static void vector_decompress(scalar *a, int bits, int rank)
-{
- while (rank-- > 0)
- scalar_decompress(a++, bits);
-}
-
/* The output scalar must not overlap with the inputs */
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
int rank)
scalar_mult_add(out, ++lhs, ++rhs);
}
-/* In-place NTT transform of a vector */
-static void vector_ntt(scalar *a, int rank)
-{
- while (rank-- > 0)
- scalar_ntt(a++);
-}
-
-/* In-place inverse NTT transform of a vector */
-static void vector_inverse_ntt(scalar *a, int rank)
-{
- while (rank-- > 0)
- scalar_inverse_ntt(a++);
-}
-
-/* Here, the output vector must not overlap with the inputs */
+/*
+ * Here, the output vector must not overlap with the inputs.  Each element of
+ * the result is directly subjected to the inverse NTT.
+ */
static void
-matrix_mult(scalar *out, const scalar *m, const scalar *a, int rank)
+matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
{
const scalar *ar;
int i, j;
scalar_mult(out, m++, ar = a);
for (j = rank - 1; j > 0; --j)
scalar_mult_add(out, m++, ++ar);
+ scalar_inverse_ntt(out);
}
}
/* Here, the output vector must not overlap with the inputs */
static void
-matrix_mult_transpose(scalar *out, const scalar *m, const scalar *a, int rank)
+matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
{
const scalar *mc = m, *mr, *ar;
int i, j;
- for (i = rank; i-- > 0; ++out) {
- scalar_mult(out, mr = mc++, ar = a);
+ for (i = rank; i-- > 0; ++out) {
+ scalar_mult_add(out, mr = mc++, ar = a);
for (j = rank; --j > 0; )
scalar_mult_add(out, (mr += rank), ++ar);
}
return 1;
}
+/*
+ * As above, but each generated scalar is also converted to NTT form.
+ *
+ * For each of the |rank| output scalars, the input passed to |cbd| is |seed|
+ * followed by one byte taken from |*counter|, which is post-incremented so
+ * that the caller can continue the sequence in later calls.  Returns 1 on
+ * success, or 0 as soon as |cbd| fails.
+ */
+static __owur
+int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
+ const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
+ EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
+{
+ uint8_t input[ML_KEM_RANDOM_BYTES + 1];
+
+ memcpy(input, seed, ML_KEM_RANDOM_BYTES);
+ while (rank-- > 0) {
+ input[ML_KEM_RANDOM_BYTES] = (*counter)++;
+ if (!cbd(out, input, mdctx, key))
+ return 0;
+ scalar_ntt(out++);
+ }
+ return 1;
+}
+
+
/* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
-static CBD_FUNC const cbd1[ML_KEM_1024 + 1] = { cbd_3, cbd_2, cbd_2 };
+static CBD_FUNC const cbd1[ML_KEM_1024_VARIANT + 1] = { cbd_3, cbd_2, cbd_2 };
/*
* FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
int dv = vinfo->dv;
/* FIPS 203 "y" vector */
- if (!gencbd_vector(y, cbd_1, &counter, r, rank, mdctx, key))
+ if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
return 0;
- vector_ntt(y, rank);
/* FIPS 203 "v" scalar */
inner_product(&v, key->t, y, rank);
scalar_inverse_ntt(&v);
/* FIPS 203 "u" vector */
- matrix_mult(u, key->m, y, rank);
- vector_inverse_ntt(u, rank);
+ matrix_mult_intt(u, key->m, y, rank);
/* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
int du = vinfo->du;
int dv = vinfo->dv;
- vector_decode(u, ctext, du, rank);
- vector_decompress(u, du, rank);
- vector_ntt(u, rank);
+ vector_decode_decompress_ntt(u, ctext, du, rank);
scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
scalar_decompress(&v, dv);
inner_product(&mask, key->s, u, rank);
* The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
* inlined.
*
- * The caller MUST also pass a pre-allocated scratch buffer |tmp| with room for
- * at least one "vector" (rank * sizeof(scalar)), and a pre-allocated digest
- * context that is not shared with any concurrent computation.
+ * The caller MUST pass a pre-allocated digest context that is not shared with
+ * any concurrent computation.
*
- * This function outputs the serialised wire-form |ek| public key into the
- * provided |pubenc| buffer, and generates the content of the |rho|, |pkhash|,
- * |t|, |m|, |s| and |z| components of the private |key| (which must have
- * preallocated space for these).
+ * This function optionally outputs the serialised wire-form |ek| public key
+ * into the provided |pubenc| buffer, and generates the content of the |rho|,
+ * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
+ * have preallocated space for these).
*
* Keys are computed from a 32-byte random |d| plus the 1 byte rank for
* domain separation. These are concatenated and hashed to produce a pair of
static __owur
int genkey(const uint8_t d[ML_KEM_RANDOM_BYTES],
const uint8_t z[ML_KEM_RANDOM_BYTES],
- scalar *tmp, EVP_MD_CTX *mdctx,
- uint8_t *pubenc, ML_KEM_KEY *key)
+ EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
{
uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
return 0;
memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
+ /* FIPS 203 |e| vector is initial value of key->t */
if (!matrix_expand(mdctx, key)
- || !gencbd_vector(key->s, cbd_1, &counter, sigma, rank, mdctx, key))
- return 0;
- vector_ntt(key->s, rank);
- /* FIPS 203 |e| vector */
- if (!gencbd_vector(tmp, cbd_1, &counter, sigma, rank, mdctx, key))
+ || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
+ || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
return 0;
- vector_ntt(tmp, rank);
- /* Fill in the public key */
- matrix_mult_transpose(key->t, key->m, key->s, rank);
- vector_add(key->t, tmp, rank);
- encode_pubkey(pubenc, key);
- if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
- return 0;
+ /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
+ matrix_mult_transpose_add(key->t, key->m, key->s, rank);
+
+ if (pubenc == NULL) {
+ /* Incremental digest of public key without in-full serialisation. */
+ if (!hash_h_pubkey(key->pkhash, mdctx, key))
+ return 0;
+ } else {
+ encode_pubkey(pubenc, key);
+ if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
+ return 0;
+ }
- /* Save "z" portion of seed for "implicit rejection" on failure */
+ /* Save |z| portion of seed for "implicit rejection" on failure. */
memcpy(key->z, z, ML_KEM_RANDOM_BYTES);
+ /* Also save |d| portion, in support of likely alternative key format. */
+ memcpy(key->d = key->z + ML_KEM_RANDOM_BYTES, d, ML_KEM_RANDOM_BYTES);
return 1;
}
return 0;
/* A public key needs space for |t| and |m| */
key->m = (key->t = p) + rank;
- /* A private key also needs space for |s| and |z| */
+ /*
+ * A private key also needs space for |s| and |z|.
+ * The |z| buffer always includes additional space for |d|, but a key's |d|
+ * pointer is left NULL when parsed from the NIST format, which omits that
+ * information. Only keys generated from a (d, z) seed pair will have a
+ * non-NULL |d| pointer.
+ */
if (private)
key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
- else
- key->z = (uint8_t *)(key->s = NULL);
return 1;
}
if (key->t == NULL)
return;
OPENSSL_free(key->t);
- key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
+ key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
}
/*
/* Retrieve the parameters of one of the ML-KEM variants */
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int variant)
{
- if (variant > ML_KEM_1024)
+ if (variant > ML_KEM_1024_VARIANT)
return NULL;
return &vinfo_map[variant];
}
key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);
- key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
+ key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
if (key->shake128_md != NULL
&& key->shake256_md != NULL
break;
case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
+ /* Duplicated keys retain |d|, if available */
+ if (key->d != NULL)
+ ret->d = ret->z + ML_KEM_RANDOM_BYTES;
break;
}
/*-
* Cleanse any sensitive data:
* - The private vector |s| is immediately followed by the FO failure
- * secret |z|, we can cleanse both in one call.
+ * secret |z|, and seed |d|, we can cleanse all three in one call.
*/
if (key->s != NULL)
- OPENSSL_cleanse(key->s, key->vinfo->vector_bytes + ML_KEM_RANDOM_BYTES);
+ OPENSSL_cleanse(key->s,
+ key->vinfo->vector_bytes + 2 * ML_KEM_RANDOM_BYTES);
/* Free the key material */
OPENSSL_free(key->t);
return ret;
}
-/* Parse input as a new private key */
+/* Parse input as a new private key */
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
ML_KEM_KEY *key)
{
}
/*
- * Generate a new keypair from a given seed, giving a deterministic result for
- * running tests. The caller can elect to not collect the encoded public key.
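+ * Generate a new keypair, either from the input seeds |d| and |z| (when
+ * non-NULL), yielding a deterministic result for running tests, or from
+ * securely generated random data (when both are NULL).  The caller can elect
+ * to not collect the encoded public key.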
*/
-int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seedlen,
- uint8_t *pubenc, size_t publen,
- ML_KEM_KEY *key)
+int ossl_ml_kem_genkey(const uint8_t *d, const uint8_t *z,
+ uint8_t *pubenc, size_t publen,
+ ML_KEM_KEY *key)
{
+ uint8_t tmpseed[2 * ML_KEM_RANDOM_BYTES];
EVP_MD_CTX *mdctx = NULL;
const ML_KEM_VINFO *vinfo;
int ret = 0;
return 0;
vinfo = key->vinfo;
- if (seed == NULL || seedlen != ML_KEM_SEED_BYTES
+ /* Both seeds or neither must be NULL */
+ if (((d == NULL) ^ (z == NULL)) != 0
|| (pubenc != NULL && publen != vinfo->pubkey_bytes)
|| (mdctx = EVP_MD_CTX_new()) == NULL)
return 0;
- if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key)) {
- const uint8_t *d = seed;
- const uint8_t *z = seed + ML_KEM_RANDOM_BYTES;
-
- /*-
- * This avoids the need to handle allocation failures for two (max 2KB
- * each) vectors and (if the caller does not want the public key) an
- * encoded public key (max 1568 bytes), that are never retained on
- * return from this function.
- * We stack-allocate these.
- */
-# define case_genkey_seed(bits) \
- case ML_KEM_##bits: \
- if (pubenc != NULL) \
- { \
- scalar tmp[ML_KEM_##bits##_RANK]; \
- \
- ret = genkey(d, z, tmp, mdctx, pubenc, key); \
- } else { \
- scalar tmp[ML_KEM_##bits##_RANK]; \
- uint8_t encbuf[ML_KEM_##bits##_PUBLIC_KEY_BYTES]; \
- \
- ret = genkey(d, z, tmp, mdctx, encbuf, key); \
- } \
- break
- switch (vinfo->variant) {
- case_genkey_seed(512);
- case_genkey_seed(768);
- case_genkey_seed(1024);
- }
-# undef case_genkey_seed
+ if (d == NULL) {
+ if (RAND_priv_bytes_ex(key->libctx, tmpseed, 2 * ML_KEM_RANDOM_BYTES,
+ key->vinfo->secbits) <= 0)
+ return 0;
+ d = tmpseed;
+ z = tmpseed + ML_KEM_RANDOM_BYTES;
}
+ if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
+ ret = genkey(d, z, mdctx, pubenc, key);
+
if (!ret)
free_storage(key);
EVP_MD_CTX_free(mdctx);
return ret;
}
-/*
- * Generate a new keypair from a random seed, using the library context's
- * private DRBG. The caller can elect to not collect the seed or the encoded
- * public key.
- */
-int ossl_ml_kem_genkey_rand(uint8_t *seed, size_t seedlen,
- uint8_t *pubenc, size_t publen,
- ML_KEM_KEY *key)
-{
- uint8_t tmpseed[ML_KEM_SEED_BYTES];
- uint8_t *sptr = seed == NULL ? tmpseed : seed;
-
- if (key == NULL
- || ossl_ml_kem_have_pubkey(key)
- || (seed != NULL && seedlen != sizeof(tmpseed)))
- return 0;
-
- if (RAND_priv_bytes_ex(key->libctx, sptr, sizeof(tmpseed),
- key->vinfo->secbits) <= 0)
- return 0;
-
- return ossl_ml_kem_genkey_seed(sptr, sizeof(tmpseed), pubenc, publen, key);
-}
-
/*
* FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
* This is the deterministic version with randomness supplied externally.
* We stack-allocate these.
*/
# define case_encap_seed(bits) \
- case ML_KEM_##bits: \
+ case ML_KEM_##bits##_VARIANT: \
{ \
scalar tmp[2 * ML_KEM_##bits##_RANK]; \
\
* retained on return from this function.
* We stack-allocate these.
*/
-# define case_decap(bits) \
- case ML_KEM_##bits: \
- { \
- uint8_t cbuf[ML_KEM_##bits##_CIPHERTEXT_BYTES]; \
- scalar tmp[2 * ML_KEM_##bits##_RANK]; \
- \
- ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
- EVP_MD_CTX_free(mdctx); \
- return ret; \
+# define case_decap(bits) \
+ case ML_KEM_##bits##_VARIANT: \
+ { \
+ uint8_t cbuf[CTEXT_BYTES(bits)]; \
+ scalar tmp[2 * ML_KEM_##bits##_RANK]; \
+ \
+ ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
+ EVP_MD_CTX_free(mdctx); \
+ return ret; \
}
switch (vinfo->variant) {
case_decap(512);
/*
* Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
-
+ *
* Licensed under the Apache License 2.0 (the "License"). You may not use
* this file except in compliance with the License. You can obtain a copy
* in the file LICENSE in the source distribution or at
# include <openssl/e_os2.h>
# include <openssl/core_dispatch.h>
# include <crypto/evp.h>
-# include <crypto/types.h>
# define ML_KEM_DEGREE 256
/*
* of unity, the polynomial (X^256+1) splits in Z_q[X] into 128 irreducible
* quadratic factors of the form (X^2 - zeta^(2i + 1)). This is used to
* implement efficient multiplication in the ring R_q via the "NTT" transform.
- *
- * 12 bits are sufficient to losslessly represent values in [0, q-1].
- * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
*/
# define ML_KEM_PRIME (ML_KEM_DEGREE * 13 + 1)
-# define ML_KEM_LOG2PRIME 12
-# define ML_KEM_INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
/*
* Various ML-KEM primitives need random input, 32-bytes at a time. Key
* - "secbits" is required security strength of the RNG for the random inputs.
*/
-/*
- * The wire form of a losslessly encoded vector (12-bits per element)
- */
-# define ML_KEM_VECTOR_BYTES(rank) \
- ((3 * ML_KEM_DEGREE / 2) * (rank))
-
-/*
- * The wire-form public key consists of the lossless encoding of the vector
- * "t" = "A" * "s" + "e", followed by public seed "rho".
- */
-# define ML_KEM_PUBKEY_BYTES(rank) \
- (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_RANDOM_BYTES)
-
-/*
- * Our internal serialised private key concatenates serialisations of "s", the
- * public key, the public key hash, and the failure secret "z".
- */
-# define ML_KEM_PRVKEY_BYTES(rank) \
- (ML_KEM_VECTOR_BYTES(rank) + ML_KEM_PUBKEY_BYTES(rank) \
- + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)
-
-/*
- * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
- * (numbers modulo the ML-KEM prime "q") are lossily encoded using as "du" and
- * "dv" bits, respectively. This encoding is the ciphertext input for
- * decapsulation.
- */
-# define ML_KEM_U_VECTOR_BYTES(rank, du) \
- ((ML_KEM_DEGREE / 8) * (du) * (rank))
-# define ML_KEM_V_SCALAR_BYTES(dv) \
- ((ML_KEM_DEGREE / 8) * (dv))
-# define ML_KEM_CTEXT_BYTES(rank, du, dv) \
- (ML_KEM_U_VECTOR_BYTES(rank, du) + ML_KEM_V_SCALAR_BYTES(dv))
-
/*
* Variant-specific constants and structures
* -----------------------------------------
*/
-# define ML_KEM_512_RANK 2
-# define ML_KEM_512_ETA1 3
-# define ML_KEM_512_ETA2 2
-# define ML_KEM_512_DU 10
-# define ML_KEM_512_DV 4
-# define ML_KEM_512_RNGSEC 128
-
-# define ML_KEM_768_RANK 3
-# define ML_KEM_768_ETA1 2
-# define ML_KEM_768_ETA2 2
-# define ML_KEM_768_DU 10
-# define ML_KEM_768_DV 4
-# define ML_KEM_768_RNGSEC 192
-
-# define ML_KEM_1024_RANK 4
-# define ML_KEM_1024_ETA1 2
-# define ML_KEM_1024_ETA2 2
-# define ML_KEM_1024_DU 11
-# define ML_KEM_1024_DV 5
-# define ML_KEM_1024_RNGSEC 256
+# define ML_KEM_512_VARIANT 0
+# define ML_KEM_512_BITS 512
+# define ML_KEM_512_RANK 2
+# define ML_KEM_512_ETA1 3
+# define ML_KEM_512_ETA2 2
+# define ML_KEM_512_DU 10
+# define ML_KEM_512_DV 4
+# define ML_KEM_512_SECBITS 128
+
+# define ML_KEM_768_VARIANT 1
+# define ML_KEM_768_BITS 768
+# define ML_KEM_768_RANK 3
+# define ML_KEM_768_ETA1 2
+# define ML_KEM_768_ETA2 2
+# define ML_KEM_768_DU 10
+# define ML_KEM_768_DV 4
+# define ML_KEM_768_SECBITS 192
+
+# define ML_KEM_1024_VARIANT 2
+# define ML_KEM_1024_BITS 1024
+# define ML_KEM_1024_RANK 4
+# define ML_KEM_1024_ETA1 2
+# define ML_KEM_1024_ETA2 2
+# define ML_KEM_1024_DU 11
+# define ML_KEM_1024_DV 5
+# define ML_KEM_1024_SECBITS 256
/*
* External variant-specific API
* -----------------------------
*/
-/*
- * Each variant parameter set is associated with an ordinal number which
- * represents that parameter set.
- */
-# define ML_KEM_512 0
-# define ML_KEM_768 1
-# define ML_KEM_1024 2
-
typedef struct {
const char *algorithm_name;
- size_t vector_bytes;
size_t prvkey_bytes;
+ size_t prvalloc;
size_t pubkey_bytes;
+ size_t puballoc;
size_t ctext_bytes;
+ size_t vector_bytes;
size_t u_vector_bytes;
- size_t puballoc;
- size_t prvalloc;
int variant;
int bits;
int rank;
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int variant);
/* Known as ML_KEM_KEY via crypto/types.h */
-struct ossl_ml_kem_key_st {
+typedef struct ossl_ml_kem_key_st {
/* Variant metadata, for one of ML-KEM-{512,768,1024} */
const ML_KEM_VINFO *vinfo;
struct ossl_ml_kem_scalar_st *m; /* Pre-computed pubkey matrix */
struct ossl_ml_kem_scalar_st *s; /* Private key secret vector */
uint8_t *z; /* Private key FO failure secret */
+ uint8_t *d; /* Private key seed */
/* Fixed-size/offset built-ins */
uint8_t rho[ML_KEM_RANDOM_BYTES]; /* Matrix recovery seed */
uint8_t pkhash[ML_KEM_PKHASH_BYTES]; /* Hash of wire-form public key */
-};
+} ML_KEM_KEY;
/* The public key is always present, when the private is */
# define ossl_ml_kem_key_vinfo(key) ((key)->vinfo)
* ----- ML-KEM key lifecycle
*/
-# ifndef OPENSSL_NO_ML_KEM
-
/*
* Allocate a "bare" key for given ML-KEM variant. Initially without any public
* or private key material.
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
ML_KEM_KEY *key);
__owur
-int ossl_ml_kem_genkey_rand(uint8_t *seed, size_t seedlen,
- uint8_t *pubenc, size_t publen,
- ML_KEM_KEY *key);
-__owur
-int ossl_ml_kem_genkey_seed(const uint8_t *seed, size_t seed_len,
- uint8_t *pubenc, size_t publen,
- ML_KEM_KEY *key);
+int ossl_ml_kem_genkey(const uint8_t *d, const uint8_t *z,
+ uint8_t *pubenc, size_t publen,
+ ML_KEM_KEY *key);
/*
* Perform an ML-KEM operation with a given ML-KEM key. The key can generally
__owur
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2);
-# endif /* OPENSSL_NO_ML_KEM */
-
#endif /* OPENSSL_HEADER_ML_KEM_H */