]> git.ipfire.org Git - thirdparty/strongswan.git/commitdiff
ml: Store decoded public/private key and matrix A on initiator
authorTobias Brunner <tobias@strongswan.org>
Mon, 28 Oct 2024 14:12:32 +0000 (15:12 +0100)
committerTobias Brunner <tobias@strongswan.org>
Fri, 22 Nov 2024 13:14:10 +0000 (14:14 +0100)
While this does require quite a bit of memory, on initiators there are
usually fewer concurrent SAs getting created so this should be less of
an issue than on a gateway that handles lots of SAs as responder.

The speed up is about 30% on the initiator during the decapsulation,
while the key generation does take a bit more time (about 3%).

src/libstrongswan/plugins/ml/ml_kem.c

index 94bf4daa64b674cc9dbd79b496d61822d072d84d..22cf305189ebbbde328653c5e992e9fbdfd0b905 100644 (file)
@@ -43,10 +43,22 @@ struct private_key_exchange_t {
        const ml_kem_params_t *params;
 
        /**
-        * Decryption/private key as initiator.
+        * Decryption/private key as initiator (array of k polynomials).
         */
        chunk_t private_key;
 
+       /**
+        * Encryption/public key and matrix A as initiator (array of k polynomials,
+        * followed by a matrix of k*k polynomials).
+        */
+       chunk_t public_key;
+
+       /**
+        * Additional key data as initiator (hash of encoded public key,
+        * rejection seed z).
+        */
+       chunk_t key_data;
+
        /**
         * Ciphertext as responder.
         */
@@ -507,12 +519,12 @@ static void poly_to_message(ml_poly_t *p, uint8_t *m)
 }
 
 /**
- * Generate a key pair from the given random seed d.
+ * Generate a key pair from the given random seed d.  Returns the encoded public
+ * key.
  *
  * Algorithm 13 in FIPS 203.
  */
-static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
-                                          chunk_t *dk)
+static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek)
 {
        const uint8_t k = this->params->k;
        const uint8_t eta1 = this->params->eta1;
@@ -521,7 +533,7 @@ static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
        uint8_t *rho = seeds;
        uint8_t *sigma = seeds + ML_KEM_SEED_LEN;
        uint8_t N = 0;
-       ml_poly_t a[k*k], s[k], e[k], t[k];
+       ml_poly_t *a, *s, e[k], *t;
        int i;
        bool success = FALSE;
 
@@ -533,12 +545,19 @@ static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
                goto err;
        }
 
+       this->public_key = chunk_alloc((k+1) * k * sizeof(ml_poly_t));
+       t = (ml_poly_t*)this->public_key.ptr;
+       a = (ml_poly_t*)this->public_key.ptr + k;
+
        /* generate matrix A */
        if (!generate_a(this, a, rho))
        {
                goto err;
        }
 
+       this->private_key = chunk_alloc(k * sizeof(ml_poly_t));
+       s = (ml_poly_t*)this->private_key.ptr;
+
        /* sample s from CBD using noise seed sigma and nonce N as input */
        for (i = 0; i < k; i++)
        {
@@ -575,16 +594,11 @@ static bool pke_keygen(private_key_exchange_t *this, chunk_t d, chunk_t *ek,
        encode_poly_arr(k, t, ek->ptr);
        memcpy(ek->ptr + k * ML_KEM_POLY_LEN, rho, ML_KEM_SEED_LEN);
 
-       /* pack private key */
-       *dk = chunk_alloc(k * ML_KEM_POLY_LEN);
-       encode_poly_arr(k, s, dk->ptr);
-
        success = TRUE;
 
 err:
        memwipe(seeds, sizeof(seeds));
        memwipe(sigma, ML_KEM_SEED_LEN);
-       memwipe(s, sizeof(s));
        memwipe(e, sizeof(e));
        return success;
 }
@@ -606,18 +620,28 @@ static bool pke_encrypt(private_key_exchange_t *this, chunk_t ek, uint8_t *m,
 
        uint8_t rho[ML_KEM_SEED_LEN];
        uint8_t N = 0;
-       ml_poly_t a[k*k], t[k], y[k], e1[k], e2, u[k], mu, v;
+       ml_poly_t a_gen[k*k], *a = a_gen, t_dec[k], *t = t_dec;
+       ml_poly_t y[k], e1[k], e2, u[k], mu, v;
        int i;
        bool success = FALSE;
 
-       /* decode polynomial t and extract seed rho from the public key */
-       decode_poly_arr(k, ek.ptr, t);
-       memcpy(rho, ek.ptr + k * ML_KEM_POLY_LEN, ML_KEM_SEED_LEN);
+       if (!this->public_key.ptr)
+       {
+               /* decode polynomial t and extract seed rho from the public key */
+               decode_poly_arr(k, ek.ptr, t);
+               memcpy(rho, ek.ptr + k * ML_KEM_POLY_LEN, ML_KEM_SEED_LEN);
 
-       /* generate matrix A */
-       if (!generate_a(this, a, rho))
+               /* generate matrix A */
+               if (!generate_a(this, a, rho))
+               {
+                       goto err;
+               }
+       }
+       else
        {
-               goto err;
+               /* as initiator, we already have the decoded polynomial and matrix A */
+               t = (ml_poly_t*)this->public_key.ptr;
+               a = (ml_poly_t*)this->public_key.ptr + k;
        }
 
        /* sample y from CBD using noise seed r and nonce N as input */
@@ -675,26 +699,26 @@ err:
 }
 
 /**
- * Decrypt message m using the given private key and ciphertext.
+ * Decrypt message m using the stored private key and given ciphertext.
  *
  * Algorithm 14 in FIPS 203.
  */
-static bool pke_decrypt(private_key_exchange_t *this, chunk_t dk,
-                                               chunk_t ciphertext, uint8_t *m)
+static bool pke_decrypt(private_key_exchange_t *this, chunk_t ciphertext,
+                                               uint8_t *m)
 {
        const uint8_t k = this->params->k;
        const uint8_t du = this->params->du;
        const uint8_t dv = this->params->dv;
 
-       ml_poly_t s[k], u[k], v, w;
+       ml_poly_t *s, u[k], v, w;
        int i;
 
        /* decode u and v from c1 and c2, the two parts of the ciphertext */
        decompress_poly_arr(k, du, ciphertext.ptr, u);
        decompress_poly_arr(1, dv, ciphertext.ptr + k * du * ML_KEM_N / 8, &v);
 
-       /* decode polynomial s from private key */
-       decode_poly_arr(k, dk.ptr, s);
+       /* we already have private key s stored */
+       s = (ml_poly_t*)this->private_key.ptr;
 
        /* calculate w = v - NTT^-1(s * NTT(u)) */
        for (i = 0; i < k; i++)
@@ -707,9 +731,6 @@ static bool pke_decrypt(private_key_exchange_t *this, chunk_t dk,
 
        /* decode plaintext message m from polynomial w */
        poly_to_message(&w, m);
-
-       memwipe(s, sizeof(s));
-       memwipe(&w, sizeof(w));
        return TRUE;
 }
 
@@ -723,7 +744,7 @@ static bool generate_keypair(private_key_exchange_t *this, chunk_t *ek)
        uint8_t dz[2*ML_KEM_SEED_LEN];
        chunk_t d = chunk_create(dz, ML_KEM_SEED_LEN);
        chunk_t z = chunk_create(dz + ML_KEM_SEED_LEN, ML_KEM_SEED_LEN);
-       chunk_t dk = chunk_empty, Hek;
+       chunk_t Hek;
        bool success = FALSE;
 
        /* get random seeds d and z */
@@ -732,17 +753,16 @@ static bool generate_keypair(private_key_exchange_t *this, chunk_t *ek)
                return FALSE;
        }
 
-       /* generate a key pair and store the private key, the public key, a hash
-        * of the latter and seed z as our secret key */
-       if (pke_keygen(this, d, ek, &dk) &&
+       /* generate a key pair and generate a hash of the latter to be stored
+        * together with the rejection seed z */
+       if (pke_keygen(this, d, ek) &&
                this->H->allocate_hash(this->H, *ek, &Hek))
        {
-               this->private_key = chunk_cat("ccmc", dk, *ek, Hek, z);
+               this->key_data = chunk_cat("mc", Hek, z);
                success = TRUE;
        }
 
        memwipe(dz, sizeof(dz));
-       chunk_clear(&dk);
        return success;
 }
 
@@ -769,25 +789,21 @@ METHOD(key_exchange_t, get_public_key, bool,
  */
 static bool decaps_shared_secret(private_key_exchange_t *this, chunk_t ciphertext)
 {
-       const uint8_t k = this->params->k;
-
-       chunk_t dk, ek, Hek, z, zc, c = chunk_empty;
+       chunk_t Hek, z, zc, c = chunk_empty;
        chunk_t m = chunk_alloca(ML_KEM_SEED_LEN);
        uint8_t Kr[2*ML_KEM_SEED_LEN];
        uint8_t *r = Kr + ML_KEM_SEED_LEN;
        bool success = FALSE;
 
-       /* get the private and public keys, a hash of the latter and seed z */
-       chunk_split(this->private_key, "mmmm",
-                               k * ML_KEM_POLY_LEN, &dk,
-                               k * ML_KEM_POLY_LEN + ML_KEM_SEED_LEN, &ek,
+       /* get the hash of the encoded public key and seed z */
+       chunk_split(this->key_data, "mm",
                                ML_KEM_SEED_LEN, &Hek,
                                ML_KEM_SEED_LEN, &z);
        /* prepare the seed to derive the implicit rejection secret */
        zc = chunk_cat("cc", z, ciphertext);
 
        /* decrypt message m */
-       if (!pke_decrypt(this, dk, ciphertext, m.ptr))
+       if (!pke_decrypt(this, ciphertext, m.ptr))
        {
                goto err;
        }
@@ -801,7 +817,7 @@ static bool decaps_shared_secret(private_key_exchange_t *this, chunk_t ciphertex
 
        /* encrypt the decrypted message again using the derived r */
        c = chunk_alloc(this->params->ct_len);
-       if (!pke_encrypt(this, ek, m.ptr, r, c))
+       if (!pke_encrypt(this, chunk_empty, m.ptr, r, c))
        {
                goto err;
        }
@@ -936,8 +952,10 @@ METHOD(key_exchange_t, set_seed, bool,
 METHOD(key_exchange_t, destroy, void,
        private_key_exchange_t *this)
 {
-       chunk_clear(&this->shared_secret);
        chunk_clear(&this->private_key);
+       chunk_clear(&this->key_data);
+       chunk_clear(&this->shared_secret);
+       chunk_free(&this->public_key);
        chunk_free(&this->ciphertext);
        DESTROY_IF(this->drbg);
        DESTROY_IF(this->shake128);