]> git.ipfire.org Git - thirdparty/strongswan.git/commitdiff
tls-hkdf: Store OKM in local variables
authorTobias Brunner <tobias@strongswan.org>
Mon, 31 Aug 2020 15:08:07 +0000 (17:08 +0200)
committerTobias Brunner <tobias@strongswan.org>
Fri, 12 Feb 2021 10:45:44 +0000 (11:45 +0100)
src/libtls/tls_hkdf.c

index 97ba35638bdc7ab72469c2f023d0c3c5d5751387..8ec55b85835b6b5bb79cec6ddb53cb49f2ac57ef 100644 (file)
@@ -66,11 +66,6 @@ struct private_tls_hkdf_t {
         */
        chunk_t prk;
 
-       /**
-        * OKM used.
-        */
-       chunk_t okm;
-
        /**
         * Current implementation needs a copy of derived secrets to calculate the
         * proper finished key.
@@ -131,12 +126,10 @@ static bool expand(private_tls_hkdf_t *this, chunk_t prk, chunk_t info,
                return FALSE;
        }
        prf_plus = prf_plus_create(this->prf, TRUE, info);
-       chunk_clear(okm);
        if (!prf_plus || !prf_plus->allocate_bytes(prf_plus, length, okm))
        {
                DBG1(DBG_TLS, "unable to allocate PRF+ result");
                DESTROY_IF(prf_plus);
-               chunk_clear(okm);
                return FALSE;
        }
        prf_plus->destroy(prf_plus);
@@ -177,7 +170,7 @@ static bool expand_label(private_tls_hkdf_t *this, chunk_t secret,
  * Derive-Secret(Secret, Label, Message) -> OKM
  */
 static bool derive_secret(private_tls_hkdf_t *this, chunk_t label,
-                                                 chunk_t messages)
+                                                 chunk_t messages, chunk_t *okm)
 {
        chunk_t context;
        bool success;
@@ -188,8 +181,7 @@ static bool derive_secret(private_tls_hkdf_t *this, chunk_t label,
        }
 
        success = expand_label(this, this->prk, label, context,
-                                                  this->hasher->get_hash_size(this->hasher),
-                                                  &this->okm);
+                                                  this->hasher->get_hash_size(this->hasher), okm);
        chunk_free(&context);
        return success;
 }
@@ -259,7 +251,7 @@ static bool move_to_phase_1(private_tls_hkdf_t *this)
  */
 static bool move_to_phase_2(private_tls_hkdf_t *this)
 {
-       chunk_t derived;
+       chunk_t derived, okm;
 
        switch (this->phase)
        {
@@ -272,7 +264,7 @@ static bool move_to_phase_2(private_tls_hkdf_t *this)
                        /* fall-through */
                case HKDF_PHASE_1:
                        derived = chunk_from_str("tls13 derived");
-                       if (!derive_secret(this, derived, chunk_empty))
+                       if (!derive_secret(this, derived, chunk_empty, &okm))
                        {
                                DBG1(DBG_TLS, "unable to derive secret");
                                return FALSE;
@@ -281,14 +273,17 @@ static bool move_to_phase_2(private_tls_hkdf_t *this)
                        if (!this->shared_secret.ptr)
                        {
                                DBG1(DBG_TLS, "no shared secret set");
+                               chunk_clear(&okm);
                                return FALSE;
                        }
 
-                       if (!extract(this, this->okm, this->shared_secret, &this->prk))
+                       if (!extract(this, okm, this->shared_secret, &this->prk))
                        {
                                DBG1(DBG_TLS, "unable extract PRK");
+                               chunk_clear(&okm);
                                return FALSE;
                        }
+                       chunk_clear(&okm);
                        this->phase = HKDF_PHASE_2;
                        return TRUE;
                case HKDF_PHASE_2:
@@ -325,7 +320,7 @@ static bool move_to_phase_2(private_tls_hkdf_t *this)
  */
 static bool move_to_phase_3(private_tls_hkdf_t *this)
 {
-       chunk_t derived, ikm_zero;
+       chunk_t derived, okm, ikm_zero;
 
        switch (this->phase)
        {
@@ -340,7 +335,7 @@ static bool move_to_phase_3(private_tls_hkdf_t *this)
                case HKDF_PHASE_2:
                        /* prepare okm for next extract */
                        derived = chunk_from_str("tls13 derived");
-                       if (!derive_secret(this, derived, chunk_empty))
+                       if (!derive_secret(this, derived, chunk_empty, &okm))
                        {
                                DBG1(DBG_TLS, "unable to derive secret");
                                return FALSE;
@@ -348,11 +343,13 @@ static bool move_to_phase_3(private_tls_hkdf_t *this)
 
                        ikm_zero = chunk_alloca(this->hasher->get_hash_size(this->hasher));
                        chunk_copy_pad(ikm_zero, chunk_empty, 0);
-                       if (!extract(this, this->okm, ikm_zero, &this->prk))
+                       if (!extract(this, okm, ikm_zero, &this->prk))
                        {
                                DBG1(DBG_TLS, "unable extract PRK");
+                               chunk_clear(&okm);
                                return FALSE;
                        }
+                       chunk_clear(&okm);
                        this->phase = HKDF_PHASE_3;
                        return TRUE;
                case HKDF_PHASE_3:
@@ -373,6 +370,8 @@ METHOD(tls_hkdf_t, generate_secret, bool,
        private_tls_hkdf_t *this, tls_hkdf_label_t label, chunk_t messages,
        chunk_t *secret)
 {
+       chunk_t okm;
+
        switch (label)
        {
                case TLS_HKDF_EXT_BINDER:
@@ -427,7 +426,7 @@ METHOD(tls_hkdf_t, generate_secret, bool,
 
                if (!expand_label(this, previous, chunk_from_str("tls13 traffic upd"),
                                                  chunk_empty, this->hasher->get_hash_size(this->hasher),
-                                                 &this->okm))
+                                                 &okm))
                {
                        DBG1(DBG_TLS, "unable to update secret");
                        return FALSE;
@@ -435,7 +434,8 @@ METHOD(tls_hkdf_t, generate_secret, bool,
        }
        else
        {
-               if (!derive_secret(this, chunk_from_str(hkdf_labels[label]), messages))
+               if (!derive_secret(this, chunk_from_str(hkdf_labels[label]), messages,
+                                                  &okm))
                {
                        DBG1(DBG_TLS, "unable to derive secret");
                        return FALSE;
@@ -448,13 +448,13 @@ METHOD(tls_hkdf_t, generate_secret, bool,
                case TLS_HKDF_C_AP_TRAFFIC:
                case TLS_HKDF_UPD_C_TRAFFIC:
                        chunk_clear(&this->client_traffic_secret);
-                       this->client_traffic_secret = chunk_clone(this->okm);
+                       this->client_traffic_secret = chunk_clone(okm);
                        break;
                case TLS_HKDF_S_HS_TRAFFIC:
                case TLS_HKDF_S_AP_TRAFFIC:
                case TLS_HKDF_UPD_S_TRAFFIC:
                        chunk_clear(&this->server_traffic_secret);
-                       this->server_traffic_secret = chunk_clone(this->okm);
+                       this->server_traffic_secret = chunk_clone(okm);
                        break;
                default:
                        break;
@@ -462,7 +462,11 @@ METHOD(tls_hkdf_t, generate_secret, bool,
 
        if (secret)
        {
-               *secret = chunk_clone(this->okm);
+               *secret = okm;
+       }
+       else
+       {
+               chunk_clear(&okm);
        }
        return TRUE;
 }
@@ -533,7 +537,6 @@ METHOD(tls_hkdf_t, destroy, void,
        chunk_clear(&this->psk);
        chunk_clear(&this->prk);
        chunk_clear(&this->shared_secret);
-       chunk_clear(&this->okm);
        chunk_clear(&this->client_traffic_secret);
        chunk_clear(&this->server_traffic_secret);
        DESTROY_IF(this->prf);