]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
ML_DSA: Use pointers instead of arrays for polynomials in Vectors and Matrix.
authorslontis <shane.lontis@oracle.com>
Mon, 9 Dec 2024 23:24:05 +0000 (10:24 +1100)
committerTomas Mraz <tomas@openssl.org>
Fri, 14 Feb 2025 09:46:03 +0000 (10:46 +0100)
A DSA_KEY when created will alloc enough space to hold its k & l
vectors and then just set the vectors to point to the allocated blob.

Local Vectors and Matricies can then be initialised in a similar way by
passing them an array of Polnomials that are on the local stack.

Reviewed-by: Viktor Dukhovni <viktor@openssl.org>
Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/26127)

crypto/ml_dsa/ml_dsa_key.c
crypto/ml_dsa/ml_dsa_key.h
crypto/ml_dsa/ml_dsa_matrix.c
crypto/ml_dsa/ml_dsa_matrix.h
crypto/ml_dsa/ml_dsa_ntt.c
crypto/ml_dsa/ml_dsa_sample.c
crypto/ml_dsa/ml_dsa_sign.c [new file with mode: 0644]
crypto/ml_dsa/ml_dsa_vector.h

index 9ab46ab1b5d632521f8efc3542307913f7fcfc85..f8a488f0d1d8aa6a872d5c6957662b96eee21e2e 100644 (file)
 ML_DSA_KEY *ossl_ml_dsa_key_new(OSSL_LIB_CTX *libctx, const char *alg)
 {
     ML_DSA_KEY *ret;
+    size_t sz;
     const ML_DSA_PARAMS *params = ossl_ml_dsa_params_get(alg);
+    POLY *poly;
 
     if (params == NULL)
         return NULL;
 
-    ret = OPENSSL_zalloc(sizeof(*ret));
+    sz = sizeof(POLY) * (params->k * 3 + params->l);
+    ret = OPENSSL_zalloc(sizeof(*ret) + sz);
     if (ret != NULL) {
         if (!CRYPTO_NEW_REF(&ret->references, 1)) {
             OPENSSL_free(ret);
@@ -39,10 +42,11 @@ ML_DSA_KEY *ossl_ml_dsa_key_new(OSSL_LIB_CTX *libctx, const char *alg)
         }
         ret->libctx = libctx;
         ret->params = params;
-        vector_init(&ret->t0, params->k);
-        vector_init(&ret->t1, params->k);
-        vector_init(&ret->s2, params->k);
-        vector_init(&ret->s1, params->l);
+        poly = (POLY *)((uint8_t *)ret + sizeof(*ret));
+        vector_init(&ret->t0, poly, params->k);
+        vector_init(&ret->t1, poly + params->k, params->k);
+        vector_init(&ret->s2, poly + 2 * params->k, params->k);
+        vector_init(&ret->s1, poly + 3 * params->k, params->l);
     }
     return ret;
 }
@@ -180,14 +184,15 @@ int ossl_ml_dsa_key_fromdata(ML_DSA_KEY *key, const OSSL_PARAM params[],
 static int public_from_private(ML_DSA_CTX *ctx, const ML_DSA_KEY *key,
                                VECTOR *t1, VECTOR *t0)
 {
+    const ML_DSA_PARAMS *params = ctx->params;
+    POLY polys[ML_DSA_K_MAX + ML_DSA_L_MAX + ML_DSA_K_MAX * ML_DSA_L_MAX];
     MATRIX a_ntt;
     VECTOR s1_ntt;
     VECTOR t;
-    const ML_DSA_PARAMS *params = ctx->params;
 
-    matrix_init(&a_ntt,  params->k, params->l);
-    vector_init(&s1_ntt, params->l);
-    vector_init(&t, params->k);
+    vector_init(&t, polys, params->k);
+    vector_init(&s1_ntt, polys + params->k, params->l);
+    matrix_init(&a_ntt, polys + params->k + params->l, params->k, params->l);
 
     /* Using rho generate A' = A in NTT form */
     if (!ossl_ml_dsa_sample_expandA(ctx->g_ctx, key->rho, &a_ntt))
@@ -214,6 +219,7 @@ int ossl_ml_dsa_key_pairwise_check(const ML_DSA_KEY *key)
     int ret = 0;
     ML_DSA_CTX *ctx = NULL;
     VECTOR t1, t0;
+    POLY polys[ML_DSA_K_MAX * 2];
 
     if (key->pub_encoding == NULL || key->priv_encoding == 0)
         return 0;
@@ -222,8 +228,8 @@ int ossl_ml_dsa_key_pairwise_check(const ML_DSA_KEY *key)
     if (ctx == NULL)
         return 0;
 
-    vector_init(&t1, key->params->k);
-    vector_init(&t0, key->params->k);
+    vector_init(&t1, polys, key->params->k);
+    vector_init(&t0, polys + key->params->k, key->params->k);
     if (!public_from_private(ctx, key, &t1, &t0))
         goto err;
 
index c841da0987d693c1785063c5d0a1ba6d66907fd4..bddba9d62cd77c561028aaf2d5424c46bff92662 100644 (file)
@@ -22,6 +22,14 @@ struct ml_dsa_key_st {
     uint8_t rho[ML_DSA_RHO_BYTES]; /* public random seed */
     uint8_t tr[ML_DSA_TR_BYTES];   /* Pre-cached public key Hash */
     uint8_t K[ML_DSA_K_BYTES];     /* Private random seed for signing */
+
+    /*
+     * The encoded public and private keys, these are non NULL if the key
+     * components are generated or loaded.
+     */
+    uint8_t *pub_encoding;
+    uint8_t *priv_encoding;
+
     /*
      * t0 is the Polynomial encoding of the 13 LSB of each coefficient of the
      * uncompressed public key polynomial t. This is saved as part of the
@@ -35,13 +43,6 @@ struct ml_dsa_key_st {
      * (There are 23 bits in q-modulus.. i.e 10 bits = 23 - 13)
      */
     VECTOR t1;
-    VECTOR s1; /* private secret of size L with short coefficients (-4..4) or (-2..2) */
     VECTOR s2; /* private secret of size K with short coefficients (-4..4) or (-2..2) */
-
-    /*
-     * The encoded public and private keys, these are non NULL if the key
-     * components are generated or loaded.
-     */
-    uint8_t *pub_encoding;
-    uint8_t *priv_encoding;
+    VECTOR s1; /* private secret of size L with short coefficients (-4..4) or (-2..2) */
 };
index 88c6be3bbb1f487d45930314a8a2f4a3933ed458..c145481c056cc846abe165f0939a44c1d2d887bf 100644 (file)
@@ -24,6 +24,7 @@ void ossl_ml_dsa_matrix_mult_vector(const MATRIX *a, const VECTOR *s,
                                     VECTOR *t)
 {
     size_t i, j;
+    POLY *poly = a->m_poly;
 
     vector_zero(t);
 
@@ -31,7 +32,7 @@ void ossl_ml_dsa_matrix_mult_vector(const MATRIX *a, const VECTOR *s,
         for (j = 0; j < a->l; j++) {
             POLY product;
 
-            ossl_ml_dsa_poly_ntt_mult(&a->m_poly[i][j], &s->poly[j], &product);
+            ossl_ml_dsa_poly_ntt_mult(poly++, &s->poly[j], &product);
             poly_add(&product, &t->poly[i], &t->poly[i]);
         }
     }
index 759a69bc71f65594aad583933a5e79961c261cb4..2a79c6c59dd8fa72ffe729003b7c97d1e91bbd47 100644 (file)
@@ -9,15 +9,16 @@
 
 /* A 'k' by 'l' Matrix object ('k' rows and 'l' columns) containing polynomial entries */
 struct matrix_st {
-    POLY m_poly[ML_DSA_K_MAX][ML_DSA_L_MAX];
+    POLY *m_poly;
     size_t k, l;
 };
 
 static ossl_inline ossl_unused void
-matrix_init(MATRIX *m, size_t k, size_t l)
+matrix_init(MATRIX *m, POLY *polys, size_t k, size_t l)
 {
     m->k = k;
     m->l = l;
+    m->m_poly = polys;
 }
 
 void ossl_ml_dsa_matrix_mult_vector(const MATRIX *matrix_kl, const VECTOR *vl,
index a79b058906dc8b02a6e1b8465c074c7007a9e35f..71509337b5894727f316f31c2f849d7bca2a7c08 100644 (file)
@@ -22,8 +22,8 @@
  * The multiplication of a.b mod q requires division by q which is a slow operation.
  *
  * When many multiplications mod q are required montgomery multiplication
- * can be used. This requires a number R > N such that R & N are coprime
- * (i.e. GCD(N, R) = 1), so that division happens using R instead of q.
+ * can be used. This requires a number R > q such that R & q are coprime
+ * (i.e. GCD(R, q) = 1), so that division happens using R instead of q.
  * If r is a power of 2 then this division can be done as a bit shift.
  *
  * Given that q = 2^23 - 2^13 + 1
index 74abdd5d67a2785f06a91aaa5fbad13a20a17900..52015c16b7ec33054c566a4701b8d54ce7dbc7a5 100644 (file)
@@ -188,6 +188,7 @@ int ossl_ml_dsa_sample_expandA(EVP_MD_CTX *g_ctx, const uint8_t *rho,
     int ret = 0;
     size_t i, j;
     uint8_t derived_seed[ML_DSA_RHO_BYTES + 2];
+    POLY *poly = out->m_poly;
 
     /* The seed used for each matrix element is rho + column_index + row_index */
     memcpy(derived_seed, rho, ML_DSA_RHO_BYTES);
@@ -197,8 +198,7 @@ int ossl_ml_dsa_sample_expandA(EVP_MD_CTX *g_ctx, const uint8_t *rho,
             derived_seed[ML_DSA_RHO_BYTES + 1] = (uint8_t)i;
             derived_seed[ML_DSA_RHO_BYTES] = (uint8_t)j;
             /* Generate the polynomial for each matrix element using a unique seed */
-            if (!rej_ntt_poly(g_ctx, derived_seed, sizeof(derived_seed),
-                              &out->m_poly[i][j]))
+            if (!rej_ntt_poly(g_ctx, derived_seed, sizeof(derived_seed), poly++))
                 goto err;
         }
     }
diff --git a/crypto/ml_dsa/ml_dsa_sign.c b/crypto/ml_dsa/ml_dsa_sign.c
new file mode 100644 (file)
index 0000000..adb9435
--- /dev/null
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License 2.0 (the "License").  You may not use
+ * this file except in compliance with the License.  You can obtain a copy
+ * in the file LICENSE in the source distribution or at
+ * https://www.openssl.org/source/license.html
+ */
+
+#include <openssl/core_dispatch.h>
+#include <openssl/core_names.h>
+#include <openssl/params.h>
+#include <openssl/rand.h>
+#include "ml_dsa_local.h"
+#include "ml_dsa_key.h"
+#include "ml_dsa_params.h"
+#include "ml_dsa_matrix.h"
+
+
+/*
+ * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
+ * @returns 1 on success and 0 on failure.
+ */
+template <int K, int L>
+static int ossl_ml_dsa_sign_internal(
+    uint8_t out_encoded_signature[signature_bytes<K>()],
+    const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
+    const uint8_t *context_prefix, size_t context_prefix_len,
+    const uint8_t *context, size_t context_len,
+    const uint8_t randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
+  uint8_t mu[kMuBytes];
+  struct BORINGSSL_keccak_st keccak_ctx;
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
+                          sizeof(priv->public_key_hash));
+  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+  uint8_t rho_prime[kRhoPrimeBytes];
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
+  BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
+                          MLDSA_SIGNATURE_RANDOMIZER_BYTES);
+  BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, kRhoPrimeBytes);
+
+  // Intermediate values, allocated on the heap to allow use when there is a
+  // limited amount of stack.
+  struct values_st {
+    struct signature<K, L> sign;
+    vector<L> s1_ntt;
+    vector<K> s2_ntt;
+    vector<K> t0_ntt;
+    matrix<K, L> a_ntt;
+    vector<L> y;
+    vector<K> w;
+    vector<K> w1;
+    vector<L> cs1;
+    vector<K> cs2;
+  };
+  std::unique_ptr<values_st, DeleterFree<values_st>> values(
+      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+  if (values == NULL) {
+    return 0;
+  }
+  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+  vector_ntt(&values->s1_ntt);
+
+  OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
+  vector_ntt(&values->s2_ntt);
+
+  OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
+  vector_ntt(&values->t0_ntt);
+
+  matrix_expand(&values->a_ntt, priv->rho);
+
+  // kappa must not exceed 2**16/L = 13107. But the probability of it
+  // exceeding even 1000 iterations is vanishingly small.
+  for (size_t kappa = 0;; kappa += L) {
+    vector_expand_mask(&values->y, rho_prime, kappa);
+
+    vector<L> *y_ntt = &values->cs1;
+    OPENSSL_memcpy(y_ntt, &values->y, sizeof(*y_ntt));
+    vector_ntt(y_ntt);
+
+    matrix_mult(&values->w, &values->a_ntt, y_ntt);
+    vector_inverse_ntt(&values->w);
+
+    vector_high_bits(&values->w1, &values->w);
+    uint8_t w1_encoded[128 * K];
+    w1_encode(w1_encoded, &values->w1);
+
+    BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+    BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+    BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
+    BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
+                             2 * lambda_bytes<K>());
+
+    scalar c_ntt;
+    scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
+                                  sizeof(values->sign.c_tilde), tau<K>());
+    scalar_ntt(&c_ntt);
+
+    vector_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
+    vector_inverse_ntt(&values->cs1);
+    vector_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
+    vector_inverse_ntt(&values->cs2);
+
+    vector_add(&values->sign.z, &values->y, &values->cs1);
+
+    vector<K> *r0 = &values->w1;
+    vector_sub(r0, &values->w, &values->cs2);
+    vector_low_bits(r0, r0);
+
+    // Leaking the fact that a signature was rejected is fine as the next
+    // attempt at a signature will be (indistinguishable from) independent of
+    // this one. Note, however, that we additionally leak which of the two
+    // branches rejected the signature. Section 5.5 of
+    // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
+    // describes this leak as OK. Note we leak less than what is described by
+    // the paper; we do not reveal which coefficient violated the bound, and
+    // we hide which of the |z_max| or |r0_max| bound failed. See also
+    // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
+    uint32_t z_max = vector_max(&values->sign.z);
+    uint32_t r0_max = vector_max_signed(r0);
+    if (constant_time_declassify_w(
+            constant_time_ge_w(z_max, gamma1<K>() - beta<K>()) |
+            constant_time_ge_w(r0_max, kGamma2 - beta<K>()))) {
+      continue;
+    }
+
+    vector<K> *ct0 = &values->w1;
+    vector_mult_scalar(ct0, &values->t0_ntt, &c_ntt);
+    vector_inverse_ntt(ct0);
+    vector_make_hint(&values->sign.h, ct0, &values->cs2, &values->w);
+
+    // See above.
+    uint32_t ct0_max = vector_max(ct0);
+    size_t h_ones = vector_count_ones(&values->sign.h);
+    if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
+                                   constant_time_lt_w(omega<K>(), h_ones))) {
+      continue;
+    }
+
+    // Although computed with the private key, the signature is public.
+    CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
+    CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
+    CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
+
+    CBB cbb;
+    CBB_init_fixed(&cbb, out_encoded_signature, signature_bytes<K>());
+    if (!mldsa_marshal_signature(&cbb, &values->sign)) {
+      return 0;
+    }
+
+    BSSL_CHECK(CBB_len(&cbb) == signature_bytes<K>());
+    return 1;
+  }
+}
+
+
+
index c21122add18a4d199e1f660adf0edd797059d9a7..5aa84511c2e3287927255927d1f29380f912dfe0 100644 (file)
@@ -9,16 +9,16 @@
 
 #include "ml_dsa_poly.h"
 
-/* Either a 1 * l column vector or a k * 1 row vector of polynomial entries */
 struct vector_st {
-    POLY poly[ML_DSA_K_MAX];
-    size_t num_poly; /* Either k or l */
+    POLY *poly;
+    size_t num_poly;
 };
 
 /* @brief Set the number of polynomial elements that will be present in the vector */
 static ossl_inline ossl_unused
-void vector_init(VECTOR *v, size_t num_polys)
+void vector_init(VECTOR *v, POLY *polys, size_t num_polys)
 {
+    v->poly = polys;
     v->num_poly = num_polys;
 }
 
@@ -64,7 +64,7 @@ static ossl_inline ossl_unused void
 vector_copy(VECTOR *dst, const VECTOR *src)
 {
     dst->num_poly = src->num_poly;
-    memcpy(dst->poly, src->poly, sizeof(src->poly));
+    memcpy(dst->poly, src->poly, src->num_poly * sizeof(src->poly[0]));
 }
 
 /* @brief return 1 if 2 vectors are equal, or 0 otherwise */