]> git.ipfire.org Git - thirdparty/gnutls.git/commitdiff
extensions: avoid looping to discover location of saved data
authorNikos Mavrogiannopoulos <nmav@redhat.com>
Fri, 29 Sep 2017 13:40:36 +0000 (15:40 +0200)
committerNikos Mavrogiannopoulos <nmav@redhat.com>
Mon, 19 Feb 2018 14:29:35 +0000 (15:29 +0100)
Signed-off-by: Nikos Mavrogiannopoulos <nmav@redhat.com>
lib/gnutls_int.h
lib/hello_ext.c

index e110631d89409b6176e116ff3e2ba3c74a1819db..d7a4e7182dbe29ca3ff922d2ca85ea425a94b5b9 100644 (file)
@@ -278,9 +278,10 @@ typedef enum recv_state_t {
  */
 #define MAX_ALGOS GNUTLS_MAX_ALGORITHM_NUM
 
-/* IDs are non-zero and allocated in a way that all values fit in 64-bit integer as (1<<val) */
+/* IDs are allocated in a way that all values fit in 64-bit integer as (1<<val) */
 typedef enum extensions_t {
-       GNUTLS_EXTENSION_MAX_RECORD_SIZE = 1,
+       GNUTLS_EXTENSION_INVALID = 0xffff,
+       GNUTLS_EXTENSION_MAX_RECORD_SIZE = 0,
        GNUTLS_EXTENSION_STATUS_REQUEST,
        GNUTLS_EXTENSION_CERT_TYPE,
        GNUTLS_EXTENSION_SUPPORTED_ECC,
@@ -1177,8 +1178,7 @@ typedef struct {
        struct hello_ext_entry_st *rexts;
        unsigned rexts_size;
 
-       struct {
-               extensions_t id;
+       struct { /* ext_data[id] contains data for extension_t id */
                gnutls_ext_priv_data_t priv;
                gnutls_ext_priv_data_t resumed_priv;
                uint8_t set;
index 6d1927bcc54b620c50a736b48d4dc1d5a6f6686a..5cd7a166cb4584483e328bd98929d3ebb307ff1c 100644 (file)
@@ -144,6 +144,8 @@ const char *gnutls_ext_get_name(unsigned int ext)
        return NULL;
 }
 
+/* Returns %GNUTLS_EXTENSION_INVALID on error
+ */
 static unsigned tls_id_to_gid(gnutls_session_t session, unsigned tls_id)
 {
        unsigned i;
@@ -158,7 +160,7 @@ static unsigned tls_id_to_gid(gnutls_session_t session, unsigned tls_id)
                        return extfunc[i]->gid;
        }
 
-       return 0;
+       return GNUTLS_EXTENSION_INVALID;
 }
 
 typedef struct hello_ext_ctx_st {
@@ -178,7 +180,7 @@ int hello_ext_parse(void *_ctx, uint16_t tls_id, const uint8_t *data, int data_s
        int ret;
 
        id = tls_id_to_gid(session, tls_id);
-       if (id == 0) { /* skip */
+       if (id == GNUTLS_EXTENSION_INVALID) { /* skip */
                return 0;
        }
 
@@ -473,24 +475,20 @@ _gnutls_ext_set_resumed_session_data(gnutls_session_t session,
                                     extensions_t id,
                                     gnutls_ext_priv_data_t data)
 {
-       int i;
        const struct hello_ext_entry_st *ext;
 
-       ext = _gnutls_ext_ptr(session, id, GNUTLS_EXT_ANY);
+       /* If this happens we need to increase the max */
+       assert(id < MAX_EXT_TYPES);
 
-       for (i = 0; i < MAX_EXT_TYPES; i++) {
-               if (session->internals.ext_data[i].id == id
-                   || (!session->internals.ext_data[i].resumed_set && !session->internals.ext_data[i].set)) {
+       ext = _gnutls_ext_ptr(session, id, GNUTLS_EXT_ANY);
+       assert(ext != NULL);
 
-                       if (session->internals.ext_data[i].resumed_set != 0)
-                               unset_resumed_ext_data(session, ext, i);
+       if (session->internals.ext_data[id].resumed_set != 0)
+               unset_resumed_ext_data(session, ext, id);
 
-                       session->internals.ext_data[i].id = id;
-                       session->internals.ext_data[i].resumed_priv = data;
-                       session->internals.ext_data[i].resumed_set = 1;
-                       return;
-               }
-       }
+       session->internals.ext_data[id].resumed_priv = data;
+       session->internals.ext_data[id].resumed_set = 1;
+       return;
 }
 
 int _gnutls_hello_ext_unpack(gnutls_session_t session, gnutls_buffer_st * packed)
@@ -550,19 +548,13 @@ unset_ext_data(gnutls_session_t session, const struct hello_ext_entry_st *ext, u
 
 void
 _gnutls_hello_ext_unset_sdata(gnutls_session_t session,
-                               extensions_t id)
+                             extensions_t id)
 {
-       int i;
        const struct hello_ext_entry_st *ext;
 
        ext = _gnutls_ext_ptr(session, id, GNUTLS_EXT_ANY);
-
-       for (i = 0; i < MAX_EXT_TYPES; i++) {
-               if (session->internals.ext_data[i].id == id) {
-                       unset_ext_data(session, ext, i);
-                       return;
-               }
-       }
+       if (ext)
+               unset_ext_data(session, ext, id);
 }
 
 static void unset_resumed_ext_data(gnutls_session_t session, const struct hello_ext_entry_st *ext, unsigned idx)
@@ -587,10 +579,11 @@ void _gnutls_hello_ext_sdata_deinit(gnutls_session_t session)
                if (!session->internals.ext_data[i].set && !session->internals.ext_data[i].resumed_set)
                        continue;
 
-               ext = _gnutls_ext_ptr(session, session->internals.ext_data[i].id, GNUTLS_EXT_ANY);
-
-               unset_ext_data(session, ext, i);
-               unset_resumed_ext_data(session, ext, i);
+               ext = _gnutls_ext_ptr(session, i, GNUTLS_EXT_ANY);
+               if (ext) {
+                       unset_ext_data(session, ext, i);
+                       unset_resumed_ext_data(session, ext, i);
+               }
        }
 }
 
@@ -602,41 +595,32 @@ void
 _gnutls_hello_ext_set_sdata(gnutls_session_t session, extensions_t id,
                             gnutls_ext_priv_data_t data)
 {
-       unsigned int i;
        const struct hello_ext_entry_st *ext;
 
-       ext = _gnutls_ext_ptr(session, id, GNUTLS_EXT_ANY);
+       assert(id < MAX_EXT_TYPES);
 
-       for (i = 0; i < MAX_EXT_TYPES; i++) {
-               if (session->internals.ext_data[i].id == id ||
-                   (!session->internals.ext_data[i].set && !session->internals.ext_data[i].resumed_set)) {
+       ext = _gnutls_ext_ptr(session, id, GNUTLS_EXT_ANY);
+       assert(ext != NULL);
 
-                       if (session->internals.ext_data[i].set != 0) {
-                               unset_ext_data(session, ext, i);
-                       }
-                       session->internals.ext_data[i].id = id;
-                       session->internals.ext_data[i].priv = data;
-                       session->internals.ext_data[i].set = 1;
-                       return;
-               }
+       if (session->internals.ext_data[id].set != 0) {
+               unset_ext_data(session, ext, id);
        }
+       session->internals.ext_data[id].priv = data;
+       session->internals.ext_data[id].set = 1;
+
+       return;
 }
 
 int
 _gnutls_hello_ext_get_sdata(gnutls_session_t session,
                            extensions_t id, gnutls_ext_priv_data_t * data)
 {
-       int i;
-
-       for (i = 0; i < MAX_EXT_TYPES; i++) {
-               if (session->internals.ext_data[i].set != 0 &&
-                   session->internals.ext_data[i].id == id)
-               {
-                       *data =
-                           session->internals.ext_data[i].priv;
-                       return 0;
-               }
+       if (session->internals.ext_data[id].set != 0) {
+               *data =
+                   session->internals.ext_data[id].priv;
+               return 0;
        }
+
        return GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE;
 }
 
@@ -645,16 +629,12 @@ _gnutls_hello_ext_get_resumed_sdata(gnutls_session_t session,
                                    extensions_t id,
                                    gnutls_ext_priv_data_t * data)
 {
-       int i;
-
-       for (i = 0; i < MAX_EXT_TYPES; i++) {
-               if (session->internals.ext_data[i].resumed_set != 0
-                   && session->internals.ext_data[i].id == id) {
-                       *data =
-                           session->internals.ext_data[i].resumed_priv;
-                       return 0;
-               }
+       if (session->internals.ext_data[id].resumed_set != 0) {
+               *data =
+                   session->internals.ext_data[id].resumed_priv;
+               return 0;
        }
+
        return GNUTLS_E_INVALID_REQUEST;
 }
 
@@ -856,7 +836,7 @@ gnutls_ext_set_data(gnutls_session_t session, unsigned tls_id,
                    gnutls_ext_priv_data_t data)
 {
        unsigned id = tls_id_to_gid(session, tls_id);
-       if (id == 0)
+       if (id == GNUTLS_EXTENSION_INVALID)
                return;
 
        _gnutls_hello_ext_set_sdata(session, id, data);
@@ -879,7 +859,7 @@ gnutls_ext_get_data(gnutls_session_t session,
                    unsigned tls_id, gnutls_ext_priv_data_t *data)
 {
        unsigned id = tls_id_to_gid(session, tls_id);
-       if (id == 0)
+       if (id == GNUTLS_EXTENSION_INVALID)
                return gnutls_assert_val(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
 
        return _gnutls_hello_ext_get_sdata(session, id, data);