]> git.ipfire.org Git - thirdparty/gnutls.git/commitdiff
handshake: introduced gnutls_session_key_update()
authorNikos Mavrogiannopoulos <nmav@redhat.com>
Thu, 19 Oct 2017 14:27:30 +0000 (16:27 +0200)
committerNikos Mavrogiannopoulos <nmav@redhat.com>
Mon, 19 Feb 2018 14:29:36 +0000 (15:29 +0100)
This function allows updating keys of the session and notifying
the peer.

Signed-off-by: Nikos Mavrogiannopoulos <nmav@redhat.com>
lib/constate.c
lib/gnutls_int.h
lib/handshake-tls13.c
lib/includes/gnutls/gnutls.h.in
lib/libgnutls.map
lib/record.c
lib/state.c
lib/tls13/key_update.c
lib/tls13/key_update.h

index 88d0e17e09317c3458654a32918bcd5d3db71096..7e2634365aba8666b2e65cb5663c42c38ac75a30 100644 (file)
@@ -201,16 +201,70 @@ _gnutls_set_keys(gnutls_session_t session, record_parameters_st * params,
 
 static int
 _tls13_update_keys(gnutls_session_t session, hs_stage_t stage,
-                  uint16_t epoch, record_parameters_st *params,
+                  record_parameters_st *params,
                   unsigned iv_size, unsigned key_size)
 {
        uint8_t key_block[MAX_CIPHER_KEY_SIZE];
        uint8_t iv_block[MAX_CIPHER_IV_SIZE];
        char buf[65];
-       record_state_st *state;
-       uint16_t *session_epoch;
+       record_state_st *upd_state;
+       record_parameters_st *prev;
        int ret;
 
+       /* generate new keys for direction needed and copy old from previous epoch */
+
+       if (stage == STAGE_UPD_OURS) {
+               upd_state = &params->write;
+
+               ret = _gnutls_epoch_get(session, EPOCH_READ_CURRENT, &prev);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               params->read.sequence_number = prev->read.sequence_number;
+               ret = _gnutls_set_datum(&params->read.key, prev->read.key.data, prev->read.key.size);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               _gnutls_hard_log("INT: READ KEY [%d]: %s\n",
+                                params->read.key.size,
+                                _gnutls_bin2hex(params->read.key.data, params->read.key.size,
+                                                buf, sizeof(buf), NULL));
+
+               ret = _gnutls_set_datum(&params->read.IV, prev->read.IV.data, prev->read.IV.size);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               _gnutls_hard_log("INT: READ IV [%d]: %s\n",
+                                params->read.IV.size,
+                                _gnutls_bin2hex(params->read.IV.data, params->read.IV.size,
+                                                buf, sizeof(buf), NULL));
+       } else {
+               upd_state = &params->read;
+
+               ret = _gnutls_epoch_get(session, EPOCH_WRITE_CURRENT, &prev);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               params->write.sequence_number = prev->write.sequence_number;
+               ret = _gnutls_set_datum(&params->write.key, prev->write.key.data, prev->write.key.size);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               _gnutls_hard_log("INT: WRITE KEY [%d]: %s\n",
+                                params->write.key.size,
+                                _gnutls_bin2hex(params->write.key.data, params->write.key.size,
+                                                buf, sizeof(buf), NULL));
+
+               ret = _gnutls_set_datum(&params->write.IV, prev->write.IV.data, prev->write.IV.size);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               _gnutls_hard_log("INT: WRITE IV [%d]: %s\n",
+                                params->write.IV.size,
+                                _gnutls_bin2hex(params->write.IV.data, params->write.IV.size,
+                                                buf, sizeof(buf), NULL));
+       }
+
        if ((session->security_parameters.entity == GNUTLS_CLIENT && stage == STAGE_UPD_OURS) ||
            (session->security_parameters.entity == GNUTLS_SERVER && stage == STAGE_UPD_PEERS)) {
                /* client keys */
@@ -229,15 +283,6 @@ _tls13_update_keys(gnutls_session_t session, hs_stage_t stage,
                ret = _tls13_expand_secret(session, "iv", 2, NULL, 0, session->key.hs_ckey, iv_size, iv_block);
                if (ret < 0)
                        return gnutls_assert_val(ret);
-
-               if (stage == STAGE_UPD_OURS) {
-                       state = &params->write;
-                       session_epoch = &session->security_parameters.epoch_write;
-               } else {
-                       state = &params->read;
-                       session_epoch = &session->security_parameters.epoch_read;
-               }
-
        } else {
                ret = _tls13_derive_secret(session, APPLICATION_TRAFFIC_UPDATE,
                                           sizeof(APPLICATION_TRAFFIC_UPDATE)-1,
@@ -254,51 +299,38 @@ _tls13_update_keys(gnutls_session_t session, hs_stage_t stage,
                ret = _tls13_expand_secret(session, "iv", 2, NULL, 0, session->key.hs_skey, iv_size, iv_block);
                if (ret < 0)
                        return gnutls_assert_val(ret);
-
-               if (stage == STAGE_UPD_OURS) {
-                       state = &params->write;
-                       session_epoch = &session->security_parameters.epoch_write;
-               } else {
-                       state = &params->read;
-                       session_epoch = &session->security_parameters.epoch_read;
-               }
        }
 
-       state->mac_secret.data = NULL;
-       state->mac_secret.size = 0;
+       upd_state->mac_secret.data = NULL;
+       upd_state->mac_secret.size = 0;
 
-       ret = _gnutls_set_datum(&state->key, key_block, key_size);
+       ret = _gnutls_set_datum(&upd_state->key, key_block, key_size);
        if (ret < 0)
                return gnutls_assert_val(GNUTLS_E_MEMORY_ERROR);
 
-       _gnutls_hard_log("INT: NEW KEY [%d]: %s\n",
+       _gnutls_hard_log("INT: NEW %s KEY [%d]: %s\n",
+                        (upd_state == &params->read)?"READ":"WRITE",
                         key_size,
                         _gnutls_bin2hex(key_block, key_size,
                                         buf, sizeof(buf), NULL));
 
        if (iv_size > 0) {
-               ret = _gnutls_set_datum(&state->IV, iv_block, iv_size);
+               ret = _gnutls_set_datum(&upd_state->IV, iv_block, iv_size);
                if (ret < 0)
                        return gnutls_assert_val(GNUTLS_E_MEMORY_ERROR);
 
-               _gnutls_hard_log("INT: NEW WRITE IV [%d]: %s\n",
+               _gnutls_hard_log("INT: NEW %s IV [%d]: %s\n",
+                                (upd_state == &params->read)?"READ":"WRITE",
                                 iv_size,
                                 _gnutls_bin2hex(iv_block, iv_size,
                                                 buf, sizeof(buf), NULL));
        }
 
-       ret = _tls13_init_record_state(params->cipher->id, state);
-       if (ret < 0)
-               return gnutls_assert_val(ret);
-
-       *session_epoch = epoch;
-
        return 0;
 }
 
 static int
 _tls13_set_keys(gnutls_session_t session, hs_stage_t stage,
-               uint16_t epoch,
                record_parameters_st * params,
                unsigned iv_size, unsigned key_size)
 {
@@ -314,7 +346,7 @@ _tls13_set_keys(gnutls_session_t session, hs_stage_t stage,
        int ret;
 
        if (stage == STAGE_UPD_OURS || stage == STAGE_UPD_PEERS)
-               return _tls13_update_keys(session, stage, epoch,
+               return _tls13_update_keys(session, stage,
                                          params, iv_size, key_size);
 
        if (stage == STAGE_HS) {
@@ -434,17 +466,6 @@ _tls13_set_keys(gnutls_session_t session, hs_stage_t stage,
                                                 buf, sizeof(buf), NULL));
        }
 
-       ret = _tls13_init_record_state(params->cipher->id, &params->read);
-       if (ret < 0)
-               return gnutls_assert_val(ret);
-
-       ret = _tls13_init_record_state(params->cipher->id, &params->write);
-       if (ret < 0)
-               return gnutls_assert_val(ret);
-
-       session->security_parameters.epoch_read = epoch;
-       session->security_parameters.epoch_write = epoch;
-
        return 0;
 }
 
@@ -525,6 +546,7 @@ _gnutls_set_cipher_suite2(gnutls_session_t session,
 }
 
 /* Sets the next epoch to be a clone of the current one.
+ * The keys are not cloned, only the cipher and MAC.
  */
 int _gnutls_epoch_dup(gnutls_session_t session)
 {
@@ -592,10 +614,17 @@ int _gnutls_epoch_set_keys(gnutls_session_t session, uint16_t epoch, hs_stage_t
 
        if (ver->tls13_sem) {
                ret = _tls13_set_keys
-                   (session, stage, epoch, params, IV_size, key_size);
+                   (session, stage, params, IV_size, key_size);
                if (ret < 0)
                        return gnutls_assert_val(ret);
 
+               ret = _tls13_init_record_state(params->cipher->id, &params->read);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
+
+               ret = _tls13_init_record_state(params->cipher->id, &params->write);
+               if (ret < 0)
+                       return gnutls_assert_val(ret);
        } else {
                ret = _gnutls_set_keys
                    (session, params, hash_size, IV_size, key_size);
@@ -957,10 +986,13 @@ int _tls13_connection_state_init(gnutls_session_t session, hs_stage_t stage)
        if (ret < 0)
                return ret;
 
-       _gnutls_handshake_log("HSK[%p]: TLS 1.3 cipher suite: %s\n",
+       _gnutls_handshake_log("HSK[%p]: TLS 1.3 re-key with cipher suite: %s\n",
                              session,
                              session->security_parameters.cs->name);
 
+       session->security_parameters.epoch_read = epoch_next;
+       session->security_parameters.epoch_write = epoch_next;
+
        return 0;
 }
 
index 4dccef2bebfbfde1277a68412328af50d0690acd..daee408e83ec47951e050c768cfe492e95eb2776 100644 (file)
@@ -167,10 +167,14 @@ typedef enum hs_stage_t {
        STAGE_UPD_PEERS
 } hs_stage_t;
 
-typedef enum record_flush_t {
-       RECORD_FLUSH = 0,
-       RECORD_CORKED,
-} record_flush_t;
+typedef enum record_send_state_t {
+       RECORD_SEND_NORMAL = 0,
+       RECORD_SEND_CORKED, /* corked and transition to NORMAL afterwards */
+       RECORD_SEND_CORKED_TO_KU, /* corked but must transition to RECORD_SEND_KEY_UPDATE_1 */
+       RECORD_SEND_KEY_UPDATE_1,
+       RECORD_SEND_KEY_UPDATE_2,
+       RECORD_SEND_KEY_UPDATE_3
+} record_send_state_t;
 
 /* the maximum size of encrypted packets */
 #define IS_DTLS(session) (session->internals.transport == GNUTLS_DGRAM)
@@ -251,7 +255,8 @@ typedef enum handshake_state_t { STATE0 = 0, STATE1, STATE2,
        STATE30 = 30, STATE31, STATE40 = 40, STATE41, STATE50 = 50,
        STATE90=90, STATE91, STATE92, STATE93,
        STATE100=100, STATE101, STATE102, STATE103, STATE104,
-       STATE105, STATE106, STATE107, STATE108, STATE109, STATE110
+       STATE105, STATE106, STATE107, STATE108, STATE109, STATE110,
+       STATE150 /* key update */
 } handshake_state_t;
 
 typedef enum bye_state_t {
@@ -983,7 +988,9 @@ typedef struct {
                                                 * send.
                                                 */
 
-       record_flush_t record_flush_mode;       /* GNUTLS_FLUSH or GNUTLS_CORKED */
+       record_send_state_t rsend_state;
+       /* buffer used temporarily during key update */
+       gnutls_buffer_st record_key_update_buffer;
        gnutls_buffer_st record_presend_buffer; /* holds cached data
                                                 * for the gnutls_record_send()
                                                 * function.
@@ -1118,12 +1125,8 @@ typedef struct {
 #define HSK_HRR_RECEIVED (1<<4)
 #define HSK_CRT_REQ_SENT (1<<5)
 #define HSK_CRT_REQ_GOT_SIG_ALGO (1<<6)
+#define HSK_KEY_UPDATE_ASKED (1<<7) /* flag is not used during handshake */
        unsigned hsk_flags; /* TLS1.3 only */
-#define KEY_UPDATE_INACTIVE 0
-#define KEY_UPDATE_SCHEDULED 1
-#define KEY_UPDATE_SENT 2
-#define KEY_UPDATE_COMPLETED 3
-       unsigned key_update_state; /* TLS1.3 only */
        time_t last_key_update;
 
        unsigned crt_requested; /* 1 if client auth was requested (i.e., client cert).
index dee7d65f40b02f76d071b71f522f29ee087ef5ab..d7e6d168d130d40e46feaff582b8391ae29949b6 100644 (file)
@@ -361,3 +361,4 @@ _gnutls13_recv_async_handshake(gnutls_session_t session, gnutls_buffer_st *buf)
 
        return 0;
 }
+
index 47bdaf8a3ce138dbbf561cdc64cdbf6046e7c805..785ed63543e1e26e703b70514a3e961e0940fa76 100644 (file)
@@ -1007,6 +1007,9 @@ void gnutls_handshake_set_timeout(gnutls_session_t session,
                                  unsigned int ms);
 int gnutls_rehandshake(gnutls_session_t session);
 
+#define GNUTLS_KU_PEER 1
+int gnutls_session_key_update(gnutls_session_t session, unsigned flags);
+
 gnutls_alert_description_t gnutls_alert_get(gnutls_session_t session);
 int gnutls_alert_send(gnutls_session_t session,
                      gnutls_alert_level_t level,
@@ -2955,7 +2958,7 @@ void gnutls_fips140_set_mode(gnutls_fips_mode_t mode, unsigned flags);
 #define GNUTLS_E_ECC_UNSUPPORTED_CURVE -322
 #define GNUTLS_E_PKCS11_REQUESTED_OBJECT_NOT_AVAILBLE -323
 #define GNUTLS_E_CERTIFICATE_LIST_UNSORTED -324
-#define GNUTLS_E_ILLEGAL_PARAMETER -325
+#define GNUTLS_E_ILLEGAL_PARAMETER -325 /* GNUTLS_A_ILLEGAL_PARAMETER */
 #define GNUTLS_E_NO_PRIORITIES_WERE_SET -326
 #define GNUTLS_E_X509_UNSUPPORTED_EXTENSION -327
 #define GNUTLS_E_SESSION_EOF -328
index faf6345293653a5faa87be8928023f7f28c5d156..8e5444745a67abd44abfff20a8d4de0d55c6a8f8 100644 (file)
@@ -1207,6 +1207,7 @@ GNUTLS_3_6_3
 GNUTLS_3_6_xx
 {
  global:
+       gnutls_session_key_update;
        gnutls_ext_get_current_msg;
 } GNUTLS_3_6_2;
 
index 66d56eb27a4ed564d71742ab557fde3d9df4f0ce..3f2d54386819748b101554c8f29bbe22e343e37b 100644 (file)
@@ -1057,7 +1057,7 @@ record_read_headers(gnutls_session_t session,
                        memset(&record->sequence, 0,
                               sizeof(record->sequence));
                        record->length = _gnutls_read_uint16(&headers[3]);
-                       record->epoch = 0;
+                       record->epoch = session->security_parameters.epoch_read;
                }
 
                _gnutls_record_log
@@ -1658,58 +1658,6 @@ ssize_t append_data_to_corked(gnutls_session_t session, const void *data, size_t
        return data_size;
 }
 
-static
-ssize_t handle_key_update(gnutls_session_t session, const void *data, size_t data_size)
-{
-       ssize_t ret;
-
-       /* do nothing, if we are in corked mode. Otherwise
-        * switch to corked mode, cache the data and send
-        * the key update */
-
-       if (session->internals.record_flush_mode == RECORD_FLUSH) {
-               gnutls_record_cork(session); /* we are not in flush mode after that */
-
-               ret = append_data_to_corked(session, data, data_size);
-               if (ret < 0)
-                       return ret;
-
-               ret = _gnutls13_send_key_update(session, 0);
-
-               session->internals.key_update_state = KEY_UPDATE_SENT;
-               if (ret < 0)
-                       return gnutls_assert_val(ret);
-
-               session->internals.key_update_state = KEY_UPDATE_COMPLETED;
-
-               ret = gnutls_record_uncork(session, 0);
-               if (ret == 0)
-                       session->internals.key_update_state = KEY_UPDATE_INACTIVE;
-               return ret;
-       } else {
-               switch(session->internals.key_update_state) {
-               case KEY_UPDATE_SCHEDULED:
-                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
-
-               case KEY_UPDATE_SENT:
-                       ret = _gnutls13_send_key_update(session, 1);
-                       if (ret < 0)
-                               return gnutls_assert_val(ret);
-
-                       session->internals.key_update_state = KEY_UPDATE_COMPLETED;
-
-                       FALLTHROUGH;
-               case KEY_UPDATE_COMPLETED:
-                       ret = gnutls_record_uncork(session, 0);
-                       if (ret == 0)
-                               session->internals.key_update_state = KEY_UPDATE_INACTIVE;
-                       return ret;
-               default:
-                       /* no state */
-                       return GNUTLS_E_INT_RET_0; /* notify fall through */
-               }
-       }
-}
 /**
  * gnutls_record_send:
  * @session: is a #gnutls_session_t type.
@@ -1749,6 +1697,8 @@ ssize_t
 gnutls_record_send(gnutls_session_t session, const void *data,
                   size_t data_size)
 {
+       int ret;
+
        if (unlikely(!session->internals.initial_negotiation_completed)) {
                /* this is to protect buggy applications from sending unencrypted
                 * data. We allow sending however, if we are in false start handshake
@@ -1757,23 +1707,45 @@ gnutls_record_send(gnutls_session_t session, const void *data,
                        return gnutls_assert_val(GNUTLS_E_UNAVAILABLE_DURING_HANDSHAKE);
        }
 
-       if (session->internals.key_update_state > KEY_UPDATE_INACTIVE) {
-               ssize_t ret;
-
-               ret = handle_key_update(session, data, data_size);
-               if (ret != GNUTLS_E_INT_RET_0)
-                       return ret;
-               /* otherwise fall through */
-       }
+       switch(session->internals.rsend_state) {
+               case RECORD_SEND_NORMAL:
+                       return _gnutls_send_int(session, GNUTLS_APPLICATION_DATA,
+                                               -1, EPOCH_WRITE_CURRENT, data,
+                                               data_size, MBUFFER_FLUSH);
+               case RECORD_SEND_CORKED:
+               case RECORD_SEND_CORKED_TO_KU:
+                       return append_data_to_corked(session, data, data_size);
+               case RECORD_SEND_KEY_UPDATE_1:
+                       _gnutls_buffer_reset(&session->internals.record_key_update_buffer);
+
+                       ret = _gnutls_buffer_append_data(&session->internals.record_key_update_buffer,
+                                                        data, data_size);
+                       if (ret < 0)
+                               return gnutls_assert_val(ret);
 
-       if (session->internals.record_flush_mode == RECORD_FLUSH) {
-               return _gnutls_send_int(session, GNUTLS_APPLICATION_DATA,
-                                       -1, EPOCH_WRITE_CURRENT, data,
-                                       data_size, MBUFFER_FLUSH);
-       } else {                /* GNUTLS_CORKED */
-               return append_data_to_corked(session, data, data_size);
+                       session->internals.rsend_state = RECORD_SEND_KEY_UPDATE_2;
+                       /* fall-through */
+               case RECORD_SEND_KEY_UPDATE_2:
+                       ret = gnutls_session_key_update(session, 0);
+                       if (ret < 0)
+                               return gnutls_assert_val(ret);
 
+                       session->internals.rsend_state = RECORD_SEND_KEY_UPDATE_3;
+                       /* fall-through */
+               case RECORD_SEND_KEY_UPDATE_3:
+                       ret = _gnutls_send_int(session, GNUTLS_APPLICATION_DATA,
+                                               -1, EPOCH_WRITE_CURRENT,
+                                               session->internals.record_key_update_buffer.data,
+                                               session->internals.record_key_update_buffer.length,
+                                               MBUFFER_FLUSH);
+                       _gnutls_buffer_clear(&session->internals.record_key_update_buffer);
+                       session->internals.rsend_state = RECORD_SEND_NORMAL;
+                       if (ret < 0)
+                               gnutls_assert();
 
+                       return ret;
+               default:
+                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
        }
 }
 
@@ -1790,7 +1762,7 @@ gnutls_record_send(gnutls_session_t session, const void *data,
  **/
 void gnutls_record_cork(gnutls_session_t session)
 {
-       session->internals.record_flush_mode = RECORD_CORKED;
+       session->internals.rsend_state = RECORD_SEND_CORKED;
 }
 
 /**
@@ -1818,12 +1790,15 @@ int gnutls_record_uncork(gnutls_session_t session, unsigned int flags)
 {
        int ret;
        ssize_t total = 0;
+       record_send_state_t orig_state = session->internals.rsend_state;
 
-       if (session->internals.record_flush_mode == RECORD_FLUSH)
+       if (orig_state == RECORD_SEND_CORKED)
+               session->internals.rsend_state = RECORD_SEND_NORMAL;
+       else if (orig_state == RECORD_SEND_CORKED_TO_KU)
+               session->internals.rsend_state = RECORD_SEND_KEY_UPDATE_1;
+       else
                return 0;       /* nothing to be done */
 
-       session->internals.record_flush_mode = RECORD_FLUSH;
-
        while (session->internals.record_presend_buffer.length > 0) {
                if (flags == GNUTLS_RECORD_WAIT) {
                        do {
@@ -1857,7 +1832,7 @@ int gnutls_record_uncork(gnutls_session_t session, unsigned int flags)
        return total;
 
       fail:
-       session->internals.record_flush_mode = RECORD_CORKED;
+       session->internals.rsend_state = orig_state;
        return ret;
 }
 
index 1aeddc01ac16519e6e48ed1d4e22b8304adce338..d48b311347a7926dac878df0fea70e1420d7749f 100644 (file)
@@ -294,6 +294,7 @@ int gnutls_init(gnutls_session_t * session, unsigned int flags)
        _gnutls_buffer_init(&(*session)->internals.hb_remote_data);
        _gnutls_buffer_init(&(*session)->internals.hb_local_data);
        _gnutls_buffer_init(&(*session)->internals.record_presend_buffer);
+       _gnutls_buffer_init(&(*session)->internals.record_key_update_buffer);
 
        _mbuffer_head_init(&(*session)->internals.record_buffer);
        _mbuffer_head_init(&(*session)->internals.record_send_buffer);
@@ -412,6 +413,7 @@ void gnutls_deinit(gnutls_session_t session)
        _gnutls_buffer_clear(&session->internals.hb_remote_data);
        _gnutls_buffer_clear(&session->internals.hb_local_data);
        _gnutls_buffer_clear(&session->internals.record_presend_buffer);
+       _gnutls_buffer_clear(&session->internals.record_key_update_buffer);
 
        _mbuffer_head_clear(&session->internals.record_buffer);
        _mbuffer_head_clear(&session->internals.record_recv_buffer);
index 59db784e5b7df850f5f510f1237ea3b0540fbeec..b93f1c289a9e760fbb3058b6ba732a68de2a5dcb 100644 (file)
@@ -69,7 +69,7 @@ int _gnutls13_recv_key_update(gnutls_session_t session, gnutls_buffer_st *buf)
 
        _gnutls_epoch_gc(session);
 
-       _gnutls_handshake_log("HSK[%p]: requested TLS 1.3 key update (%u)\n",
+       _gnutls_handshake_log("HSK[%p]: received TLS 1.3 key update (%u)\n",
                              session, (unsigned)buf->data[0]);
 
        switch(buf->data[0]) {
@@ -81,6 +81,12 @@ int _gnutls13_recv_key_update(gnutls_session_t session, gnutls_buffer_st *buf)
 
                break;
        case 1:
+               if (session->internals.hsk_flags & HSK_KEY_UPDATE_ASKED) {
+                       /* if we had asked a key update we shouldn't get this
+                        * reply */
+                       return gnutls_assert_val(GNUTLS_E_ILLEGAL_PARAMETER);
+               }
+
                /* peer updated its key, requested our key update */
                ret = update_keys(session, STAGE_UPD_PEERS);
                if (ret < 0)
@@ -90,48 +96,107 @@ int _gnutls13_recv_key_update(gnutls_session_t session, gnutls_buffer_st *buf)
                 * will be performed prior to sending the next application
                 * message.
                 */
-               session->internals.key_update_state = KEY_UPDATE_SCHEDULED;
+               if (session->internals.rsend_state == RECORD_SEND_NORMAL)
+                       session->internals.rsend_state = RECORD_SEND_KEY_UPDATE_1;
+               else if (session->internals.rsend_state == RECORD_SEND_CORKED)
+                       session->internals.rsend_state = RECORD_SEND_CORKED_TO_KU;
+               else
+                       return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_PARAMETER);
 
                break;
        default:
                return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_PARAMETER);
        }
 
+       session->internals.hsk_flags &= ~(unsigned)(HSK_KEY_UPDATE_ASKED);
+
        return 0;
 }
 
-int _gnutls13_send_key_update(gnutls_session_t session, unsigned again)
+int _gnutls13_send_key_update(gnutls_session_t session, unsigned again, unsigned flags /* GNUTLS_KU_* */)
 {
-       int ret, ret2;
+       int ret;
        mbuffer_st *bufel = NULL;
-       const uint8_t val = 0;
+       uint8_t val;
 
        if (again == 0) {
-               _gnutls_handshake_log("HSK[%p]: sending key update\n", session);
+               if (flags & GNUTLS_KU_PEER) {
+                       /* mark that we asked a key update to prevent an
+                        * infinite ping pong when receiving the reply */
+                       session->internals.hsk_flags |= HSK_KEY_UPDATE_ASKED;
+                       val = 0x01;
+               } else {
+                       val = 0x00;
+               }
+
+               _gnutls_handshake_log("HSK[%p]: sending key update (%u)\n", session, (unsigned)val);
 
                bufel = _gnutls_handshake_alloc(session, 1);
                if (bufel == NULL)
                        return gnutls_assert_val(GNUTLS_E_MEMORY_ERROR);
 
                _mbuffer_set_udata_size(bufel, 0);
-               ret = _mbuffer_append_data(bufel, (void*)&val, 1);
+               ret = _mbuffer_append_data(bufel, &val, 1);
                if (ret < 0) {
                        gnutls_assert();
                        goto cleanup;
                }
-       }
 
-       ret = _gnutls_send_handshake(session, bufel, GNUTLS_HANDSHAKE_KEY_UPDATE);
-       if (ret == 0) {
-               /* it was completely sent, update the keys */
-               ret2 = update_keys(session, STAGE_UPD_OURS);
-               if (ret2 < 0)
-                       return gnutls_assert_val(ret2);
        }
 
-       return ret;
+       return _gnutls_send_handshake(session, bufel, GNUTLS_HANDSHAKE_KEY_UPDATE);
 
 cleanup:
        _mbuffer_xfree(&bufel);
        return ret;
 }
+
+/**
+ * gnutls_session_key_update:
+ * @session: is a #gnutls_session_t type.
+ * @flags: zero of %GNUTLS_KU_PEER
+ *
+ * This function will update/refresh the session keys when the
+ * TLS protocol is 1.3 or better. The peer is notified of the
+ * update by sending a message, so this function should be
+ * treated similarly to gnutls_record_send() --i.e., it may
+ * return %GNUTLS_E_AGAIN or %GNUTLS_E_INTERRUPTED.
+ *
+ * When this flag %GNUTLS_KU_PEER is specified, this function
+ * in addition to updating the local keys, will ask the peer to
+ * refresh its keys too.
+ *
+ * If the negotiated version is not TLS 1.3 or better this
+ * function will return %GNUTLS_E_INVALID_REQUEST.
+ *
+ * Returns: %GNUTLS_E_SUCCESS on success, otherwise a negative error code.
+ *
+ * Since: 3.6.xx
+ **/
+int gnutls_session_key_update(gnutls_session_t session, unsigned flags)
+{
+       int ret;
+       const version_entry_st *vers = get_version(session);
+
+       if (!vers->tls13_sem)
+               return GNUTLS_E_INVALID_REQUEST;
+
+       ret =
+           _gnutls13_send_key_update(session, AGAIN(STATE150), flags);
+       STATE = STATE150;
+
+       if (ret < 0) {
+               gnutls_assert();
+               return ret;
+       }
+       STATE = STATE0;
+
+       _gnutls_epoch_gc(session);
+
+       /* it was completely sent, update the keys */
+       ret = update_keys(session, STAGE_UPD_OURS);
+       if (ret < 0)
+               return gnutls_assert_val(ret);
+
+       return 0;
+}
index 0b313581bb1c66756719e977574e864cb1358075..41038cb3bf16cbac882d7c740e4f71551917b88d 100644 (file)
@@ -21,4 +21,4 @@
  */
 
 int _gnutls13_recv_key_update(gnutls_session_t session, gnutls_buffer_st *buf);
-int _gnutls13_send_key_update(gnutls_session_t session, unsigned again);
+int _gnutls13_send_key_update(gnutls_session_t session, unsigned again, unsigned flags);