]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Use secure memory allocation for ML-KEM and ML-DSA private key storage areas
authorDaniel Frink <daniel.frink@ibm.com>
Tue, 13 May 2025 20:27:05 +0000 (15:27 -0500)
committerTomas Mraz <tomas@openssl.org>
Mon, 7 Jul 2025 13:40:47 +0000 (15:40 +0200)
Resolves: #27603

Reviewed-by: Shane Lontis <shane.lontis@oracle.com>
Reviewed-by: Viktor Dukhovni <viktor@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/27625)

crypto/ml_dsa/ml_dsa_encoders.c
crypto/ml_dsa/ml_dsa_key.c
crypto/ml_dsa/ml_dsa_vector.h
crypto/ml_kem/ml_kem.c
providers/implementations/encode_decode/ml_kem_codecs.c
providers/implementations/keymgmt/ml_kem_kmgmt.c.in

index 2dca3ae060ab2b968712c22e9c6f9b737c29919b..83971bc1f7dec4b12cff195a753381beedfc4f40 100644 (file)
@@ -713,7 +713,7 @@ int ossl_ml_dsa_sk_encode(ML_DSA_KEY *key)
     size_t enc_len = params->sk_len;
     const POLY *t0 = key->t0.poly;
     WPACKET pkt;
-    uint8_t *enc = OPENSSL_malloc(enc_len);
+    uint8_t *enc = OPENSSL_secure_malloc(enc_len);
 
     if (enc == NULL)
         return 0;
@@ -741,7 +741,7 @@ int ossl_ml_dsa_sk_encode(ML_DSA_KEY *key)
     if (!WPACKET_get_total_written(&pkt, &written)
             || written != enc_len)
         goto err;
-    OPENSSL_clear_free(key->priv_encoding, enc_len);
+    OPENSSL_secure_clear_free(key->priv_encoding, enc_len);
     key->priv_encoding = enc;
     ret = 1;
 err:
@@ -805,9 +805,12 @@ int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len)
             goto err;
     if (PACKET_remaining(&pkt) != 0)
         goto err;
-    if (key->priv_encoding == NULL
-        && (key->priv_encoding = OPENSSL_memdup(in, in_len)) == NULL)
-        goto err;
+    if (key->priv_encoding == NULL) {
+        key->priv_encoding = OPENSSL_secure_malloc(in_len);
+        if (key->priv_encoding == NULL)
+            goto err;
+        memcpy(key->priv_encoding, in, in_len);
+    }
     /*
      * Computing the public key also computes its hash, which must be equal to
      * the |tr| value in the private key, else the key was corrupted.
index 0deeb7f6bb2e854ccf7afd4713ccd81d109bef72..28f4106cfd5ceeba45d5fd97bf7ffb2c5de2578a 100644 (file)
@@ -48,9 +48,13 @@ int ossl_ml_dsa_set_prekey(ML_DSA_KEY *key, int flags_set, int flags_clr,
         || key->seed != NULL)
         return 0;
 
-    if (sk != NULL
-        && (key->priv_encoding = OPENSSL_memdup(sk, sk_len)) == NULL)
-        goto end;
+    if (sk != NULL) {
+        key->priv_encoding = OPENSSL_secure_malloc(sk_len);
+        if (key->priv_encoding == NULL)
+            goto end;
+        memcpy(key->priv_encoding, sk, sk_len);
+    }
+
     if (seed != NULL
         && (key->seed = OPENSSL_memdup(seed, seed_len)) == NULL)
         goto end;
@@ -60,7 +64,7 @@ int ossl_ml_dsa_set_prekey(ML_DSA_KEY *key, int flags_set, int flags_clr,
 
  end:
     if (!ret) {
-        OPENSSL_free(key->priv_encoding);
+        OPENSSL_secure_free(key->priv_encoding);
         OPENSSL_free(key->seed);
         key->priv_encoding = key->seed = NULL;
     }
@@ -114,7 +118,7 @@ int ossl_ml_dsa_key_priv_alloc(ML_DSA_KEY *key)
 
     if (key->s1.poly != NULL)
         return 0;
-    if (!vector_alloc(&key->s1, l + 2 * k))
+    if (!vector_secure_alloc(&key->s1, l + 2 * k))
         return 0;
 
     poly = key->s1.poly;
@@ -151,7 +155,7 @@ void ossl_ml_dsa_key_reset(ML_DSA_KEY *key)
         vector_zero(&key->s1);
         vector_zero(&key->s2);
         vector_zero(&key->t0);
-        vector_free(&key->s1);
+        vector_secure_free(&key->s1);
         key->s2.poly = NULL;
         key->t0.poly = NULL;
     }
@@ -161,7 +165,7 @@ void ossl_ml_dsa_key_reset(ML_DSA_KEY *key)
     OPENSSL_free(key->pub_encoding);
     key->pub_encoding = NULL;
     if (key->priv_encoding != NULL)
-        OPENSSL_clear_free(key->priv_encoding, key->params->sk_len);
+        OPENSSL_secure_clear_free(key->priv_encoding, key->params->sk_len);
     key->priv_encoding = NULL;
     if (key->seed != NULL)
         OPENSSL_clear_free(key->seed, ML_DSA_SEED_BYTES);
@@ -217,10 +221,10 @@ ML_DSA_KEY *ossl_ml_dsa_key_dup(const ML_DSA_KEY *src, int selection)
                         vector_copy(&ret->s2, &src->s2);
                         vector_copy(&ret->t0, &src->t0);
                     }
-                    if ((ret->priv_encoding =
-                            OPENSSL_memdup(src->priv_encoding,
-                                           src->params->sk_len)) == NULL)
+                    ret->priv_encoding = OPENSSL_secure_malloc(src->params->sk_len);
+                    if (!ret->priv_encoding)
                         goto err;
+                    memcpy(ret->priv_encoding, src->priv_encoding, src->params->sk_len);
                 }
                 if (src->seed != NULL
                     && (ret->seed = OPENSSL_memdup(src->seed,
index 125e3257e45e1684cc6d912627c69509ee4df4a8..1aa7bb40828e092922f0957cd2edc0f7331cd7be 100644 (file)
@@ -40,6 +40,16 @@ int vector_alloc(VECTOR *v, size_t num_polys)
     return 1;
 }
 
+static ossl_inline ossl_unused
+int vector_secure_alloc(VECTOR *v, size_t num_polys)
+{
+    v->poly = OPENSSL_secure_malloc(num_polys * sizeof(POLY));
+    if (v->poly == NULL)
+        return 0;
+    v->num_poly = num_polys;
+    return 1;
+}
+
 static ossl_inline ossl_unused
 void vector_free(VECTOR *v)
 {
@@ -48,6 +58,14 @@ void vector_free(VECTOR *v)
     v->num_poly = 0;
 }
 
+static ossl_inline ossl_unused
+void vector_secure_free(VECTOR *v)
+{
+    OPENSSL_secure_clear_free(v->poly, v->num_poly * sizeof(POLY));
+    v->poly = NULL;
+    v->num_poly = 0;
+}
+
 /* @brief zeroize a vectors polynomial coefficients */
 static ossl_inline ossl_unused
 void vector_zero(VECTOR *va)
index 56a4bcf2fd11631d32d52a80da313974724ffd62..0446675c16f5786105b2b822345c4b3e394217e2 100644 (file)
@@ -1582,11 +1582,18 @@ ossl_ml_kem_key_reset(ML_KEM_KEY *key)
      *   secret |z|, and seed |d|, we can cleanse all three in one call.
      *
      * - Otherwise, when key->d is set, cleanse the stashed seed.
+     *
+     * If the memory has been allocated with secure memory, it will be cleared
+     * before being free'd under the OPENSSL_secure_free call.
      */
-    if (ossl_ml_kem_have_prvkey(key))
-        OPENSSL_cleanse(key->s,
-                        key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
-    OPENSSL_free(key->t);
+    if (ossl_ml_kem_have_prvkey(key)) {
+        if (!CRYPTO_secure_allocated(key->t))
+            OPENSSL_cleanse(key->s, key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
+        OPENSSL_secure_free(key->t);
+    } else {
+        OPENSSL_free(key->t);
+    }
+
     key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
 }
 
@@ -1653,6 +1660,7 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
 {
     int ok = 0;
     ML_KEM_KEY *ret;
+    void *tmp;
 
     /*
      * Partially decoded keys, not yet imported or loaded, should never be
@@ -1683,7 +1691,11 @@ ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
         ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
         break;
     case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
-        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
+        tmp = OPENSSL_secure_malloc(key->vinfo->prvalloc);
+        if (tmp == NULL)
+            break;
+        memcpy(tmp, key->t, key->vinfo->prvalloc);
+        ok = add_storage(tmp, 1, ret);
         /* Duplicated keys retain |d|, if available */
         if (key->d != NULL)
             ret->d = ret->z + ML_KEM_RANDOM_BYTES;
@@ -1715,10 +1727,8 @@ void ossl_ml_kem_key_free(ML_KEM_KEY *key)
 
     if (ossl_ml_kem_decoded_key(key)) {
         OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
-        if (ossl_ml_kem_have_dkenc(key)) {
-            OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
-            OPENSSL_free(key->encoded_dk);
-        }
+        if (ossl_ml_kem_have_dkenc(key))
+            OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);
     }
     ossl_ml_kem_key_reset(key);
     OPENSSL_free(key);
@@ -1831,7 +1841,7 @@ int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
         || (mdctx = EVP_MD_CTX_new()) == NULL)
         return 0;
 
-    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
+    if (add_storage(OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
         ret = parse_prvkey(in, mdctx, key);
 
     if (!ret)
@@ -1878,7 +1888,7 @@ int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
      */
     CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
 
-    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
+    if (add_storage(OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
         ret = genkey(seed, mdctx, pubenc, key);
     OPENSSL_cleanse(seed, sizeof(seed));
 
index f2ade8790969c30bf593794c0d272f830143f1a4..fdf6e7bf478b8a7d1977781d393449dc28f3fa40 100644 (file)
@@ -254,7 +254,7 @@ ossl_ml_kem_d2i_PKCS8(const uint8_t *prvenc, int prvlen,
         }
     }
     if (p8fmt->priv_length > 0) {
-        if ((key->encoded_dk = OPENSSL_malloc(p8fmt->priv_length)) == NULL) {
+        if ((key->encoded_dk = OPENSSL_secure_malloc(p8fmt->priv_length)) == NULL) {
             ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
                            "error parsing %s private key",
                            v->algorithm_name);
index 12d9baa2c1df554ae63a8ea4e1105aac5e9b8c43..9a13d7a8ce56fa674e28493252a4390bb4d74319 100644 (file)
@@ -537,12 +537,12 @@ void *ml_kem_load(const void *reference, size_t reference_sz)
             if (!ml_kem_pairwise_test(key, key->prov_flags))
                 goto err;
         }
-        OPENSSL_free(encoded_dk);
+        OPENSSL_secure_clear_free(encoded_dk, key->vinfo->prvkey_bytes);
         return key;
     }
 
  err:
-    OPENSSL_free(encoded_dk);
+    OPENSSL_secure_clear_free(encoded_dk, key->vinfo->prvkey_bytes);
     ossl_ml_kem_key_free(key);
     return NULL;
 }