]> git.ipfire.org Git - thirdparty/gnutls.git/commitdiff
extensions: apply extension msg type restrictions
authorNikos Mavrogiannopoulos <nmav@gnutls.org>
Mon, 11 Sep 2017 08:26:44 +0000 (10:26 +0200)
committerNikos Mavrogiannopoulos <nmav@redhat.com>
Mon, 19 Feb 2018 14:29:33 +0000 (15:29 +0100)
That is, on the extension parsing functions ensure that
no extension which are not valid for the currently
received message are parsed.

Signed-off-by: Nikos Mavrogiannopoulos <nmav@gnutls.org>
lib/extensions.c
lib/extensions.h
lib/handshake.c

index 618469dc134a6c425528e9b241a4e165ab031cfb..884e2ef7a69fb8cbdb93d022a85d3b10230ac4c1 100644 (file)
@@ -199,6 +199,7 @@ void _gnutls_extension_list_add_sr(gnutls_session_t session)
 
 int
 _gnutls_parse_extensions(gnutls_session_t session,
+                        gnutls_ext_flags_t msg,
                         gnutls_ext_parse_type_t parse_type,
                         const uint8_t * data, int data_size)
 {
@@ -255,6 +256,15 @@ _gnutls_parse_extensions(gnutls_session_t session,
                        continue;
                }
 
+
+               if ((ext->validity & msg) == 0) {
+
+                       _gnutls_debug_log("EXT[%p]: Received unexpected extension (%s/%d) for '%s'\n", session,
+                                         gnutls_ext_get_name(id), (int)id,
+                                         ext_msg_validity_to_str(msg));
+                       return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_EXTENSION);
+               }
+
                if (session->security_parameters.entity == GNUTLS_SERVER) {
                        ret = _gnutls_extension_list_add(session, ext, 1);
                        if (ret == 0)
@@ -283,7 +293,9 @@ _gnutls_parse_extensions(gnutls_session_t session,
 
 static
 int send_extension(gnutls_session_t session, const extension_entry_st *p,
-                  gnutls_buffer_st *extdata, gnutls_ext_parse_type_t parse_type)
+                  gnutls_buffer_st *extdata,
+                  gnutls_ext_flags_t msg,
+                  gnutls_ext_parse_type_t parse_type)
 {
        int size_pos, appended, ret;
        size_t size_prev;
@@ -295,6 +307,13 @@ int send_extension(gnutls_session_t session, const extension_entry_st *p,
            && p->parse_type != parse_type)
                return 0;
 
+       if ((msg & p->validity) == 0) {
+               _gnutls_handshake_log("EXT[%p]: Not sending extension (%s/%d) for '%s'\n", session,
+                                 gnutls_ext_get_name(p->id), (int)p->id,
+                                 ext_msg_validity_to_str(msg));
+               return 0;
+       }
+
        /* ensure we don't send something twice (i.e, overriden extensions in
         * client), and ensure we are sending only what we received in server. */
        ret = _gnutls_extension_list_check(session, p->id);
@@ -352,6 +371,7 @@ int send_extension(gnutls_session_t session, const extension_entry_st *p,
 int
 _gnutls_gen_extensions(gnutls_session_t session,
                       gnutls_buffer_st * extdata,
+                      gnutls_ext_flags_t msg,
                       gnutls_ext_parse_type_t parse_type)
 {
        int size;
@@ -365,7 +385,7 @@ _gnutls_gen_extensions(gnutls_session_t session,
                return gnutls_assert_val(ret);
 
        for (i=0; i < session->internals.rexts_size; i++) {
-               ret = send_extension(session, &session->internals.rexts[i], extdata, parse_type);
+               ret = send_extension(session, &session->internals.rexts[i], extdata, msg, parse_type);
                if (ret < 0)
                        return gnutls_assert_val(ret);
        }
@@ -373,7 +393,7 @@ _gnutls_gen_extensions(gnutls_session_t session,
        /* send_extension() ensures we don't send duplicates, in case
         * of overriden extensions */
        for (i = 0; extfunc[i] != NULL; i++) {
-               ret = send_extension(session, extfunc[i], extdata, parse_type);
+               ret = send_extension(session, extfunc[i], extdata, msg, parse_type);
                if (ret < 0)
                        return gnutls_assert_val(ret);
        }
@@ -697,6 +717,9 @@ _gnutls_ext_get_resumed_session_data(gnutls_session_t session,
  * structure using gnutls_ext_set_data(), and they can be retrieved using
  * gnutls_ext_get_data().
  *
+ * Any extensions registered with this function are valid for the client
+ * and TLS1.2 server hello (or encrypted extensions for TLS1.3).
+ *
  * This function is not thread safe.
  *
  * Returns: %GNUTLS_E_SUCCESS on success, otherwise a negative error code.
@@ -731,6 +754,7 @@ gnutls_ext_register(const char *name, int id, gnutls_ext_parse_type_t parse_type
        tmp_mod->deinit_func = deinit_func;
        tmp_mod->pack_func = pack_func;
        tmp_mod->unpack_func = unpack_func;
+       tmp_mod->validity = GNUTLS_EXT_FLAG_CLIENT_HELLO|GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO|GNUTLS_EXT_FLAG_EE;
 
        ret = ext_register(tmp_mod);
        if (ret < 0) {
@@ -740,6 +764,11 @@ gnutls_ext_register(const char *name, int id, gnutls_ext_parse_type_t parse_type
        return ret;
 }
 
+#define VALIDITY_MASK (GNUTLS_EXT_FLAG_CLIENT_HELLO|GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO| \
+                       GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO| \
+                       GNUTLS_EXT_FLAG_EE|GNUTLS_EXT_FLAG_CT|GNUTLS_EXT_FLAG_CR| \
+                       GNUTLS_EXT_FLAG_NST|GNUTLS_EXT_FLAG_HRR)
+
 /**
  * gnutls_session_ext_register:
  * @session: the session for which this extension will be set
@@ -766,6 +795,10 @@ gnutls_ext_register(const char *name, int id, gnutls_ext_parse_type_t parse_type
  * structure using gnutls_ext_set_data(), and they can be retrieved using
  * gnutls_ext_get_data().
  *
+ * The validity of the extension registered can be given by the appropriate flags
+ * of %gnutls_ext_flags_t. If no validity is given, then the registered extension
+ * will be valid for client and TLS1.2 server hello (or encrypted extensions for TLS1.3).
+ *
  * Returns: %GNUTLS_E_SUCCESS on success, otherwise a negative error code.
  *
  * Since: 3.5.5
@@ -803,6 +836,11 @@ gnutls_session_ext_register(gnutls_session_t session,
        tmp_mod.deinit_func = deinit_func;
        tmp_mod.pack_func = pack_func;
        tmp_mod.unpack_func = unpack_func;
+       tmp_mod.validity = flags;
+
+       if ((tmp_mod.validity & VALIDITY_MASK) == 0) {
+               tmp_mod.validity = GNUTLS_EXT_FLAG_CLIENT_HELLO|GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO|GNUTLS_EXT_FLAG_EE;
+       }
 
        exts = gnutls_realloc(session->internals.rexts, (session->internals.rexts_size+1)*sizeof(*exts));
        if (exts == NULL) {
index b06ff6c94c52555347ba367c65e3336ea1099b4b..3ef9d445c2a76e1bbe7f9a7e1c5efb305f2eed75 100644 (file)
 #include <gnutls/gnutls.h>
 
 int _gnutls_parse_extensions(gnutls_session_t session,
+                            gnutls_ext_flags_t msg,
                             gnutls_ext_parse_type_t parse_type,
                             const uint8_t * data, int data_size);
 int _gnutls_gen_extensions(gnutls_session_t session,
                           gnutls_buffer_st * extdata,
+                          gnutls_ext_flags_t msg,
                           gnutls_ext_parse_type_t);
 int _gnutls_ext_init(void);
 void _gnutls_ext_deinit(void);
@@ -56,6 +58,30 @@ int _gnutls_ext_pack(gnutls_session_t session, gnutls_buffer_st * packed);
 int _gnutls_ext_unpack(gnutls_session_t session,
                       gnutls_buffer_st * packed);
 
+inline static const char *ext_msg_validity_to_str(gnutls_ext_flags_t msg)
+{
+       switch(msg) {
+               case GNUTLS_EXT_FLAG_CLIENT_HELLO:
+                       return "client hello";
+               case GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO:
+                       return "TLS 1.2 server hello";
+               case GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO:
+                       return "TLS 1.3 server hello";
+               case GNUTLS_EXT_FLAG_EE:
+                       return "encrypted extensions";
+               case GNUTLS_EXT_FLAG_CT:
+                       return "certificate";
+               case GNUTLS_EXT_FLAG_CR:
+                       return "certificate request";
+               case GNUTLS_EXT_FLAG_NST:
+                       return "new session ticket";
+               case GNUTLS_EXT_FLAG_HRR:
+                       return "hello retry request";
+               default:
+                       return "(unknown)";
+       }
+}
+
 typedef struct extension_entry_st {
        const char *name; /* const overriden when free_struct is set */
        unsigned free_struct;
index d4d20afa6c6b81da741dc34d105984a0da6b16ee..da1e24068aee4b4399776d1ba0c801f2b7bf7aee 100644 (file)
@@ -544,7 +544,8 @@ read_client_hello(gnutls_session_t session, uint8_t * data,
         * resumed ones.
         */
        ret =
-           _gnutls_parse_extensions(session, GNUTLS_EXT_MANDATORY,
+           _gnutls_parse_extensions(session, GNUTLS_EXT_FLAG_CLIENT_HELLO,
+                                    GNUTLS_EXT_MANDATORY,
                                     ext_ptr, ext_size);
        if (ret < 0) {
                gnutls_assert();
@@ -584,7 +585,8 @@ read_client_hello(gnutls_session_t session, uint8_t * data,
         * Unconditionally try to parse extensions; safe renegotiation uses them in
         * sslv3 and higher, even though sslv3 doesn't officially support them.
         */
-       ret = _gnutls_parse_extensions(session, GNUTLS_EXT_APPLICATION,
+       ret = _gnutls_parse_extensions(session, GNUTLS_EXT_FLAG_CLIENT_HELLO,
+                                      GNUTLS_EXT_APPLICATION,
                                       ext_ptr, ext_size);
        /* len is the rest of the parsed length */
        if (ret < 0) {
@@ -601,7 +603,8 @@ read_client_hello(gnutls_session_t session, uint8_t * data,
 
        /* Session tickets are parsed in this point */
        ret =
-           _gnutls_parse_extensions(session, GNUTLS_EXT_TLS, ext_ptr, ext_size);
+           _gnutls_parse_extensions(session, GNUTLS_EXT_FLAG_CLIENT_HELLO,
+                                    GNUTLS_EXT_TLS, ext_ptr, ext_size);
        if (ret < 0) {
                gnutls_assert();
                return ret;
@@ -655,7 +658,8 @@ read_client_hello(gnutls_session_t session, uint8_t * data,
        /* call extensions that are intended to be parsed after the ciphersuite/cert
         * are known. */
        ret =
-           _gnutls_parse_extensions(session, _GNUTLS_EXT_TLS_POST_CS, ext_ptr, ext_size);
+           _gnutls_parse_extensions(session, GNUTLS_EXT_FLAG_CLIENT_HELLO,
+                                    _GNUTLS_EXT_TLS_POST_CS, ext_ptr, ext_size);
        if (ret < 0) {
                gnutls_assert();
                return ret;
@@ -1516,6 +1520,8 @@ read_server_hello(gnutls_session_t session,
        gnutls_protocol_t version;
        int len = datalen;
        const version_entry_st *vers;
+       gnutls_ext_flags_t ext_parse_flag;
+
        if (datalen < GNUTLS_RANDOM_SIZE+2) {
                gnutls_assert();
                return GNUTLS_E_UNEXPECTED_PACKET_LENGTH;
@@ -1566,7 +1572,8 @@ read_server_hello(gnutls_session_t session,
                        DECR_LEN(len, 2 + 1);
 
                        ret =
-                           _gnutls_parse_extensions(session, GNUTLS_EXT_MANDATORY,
+                           _gnutls_parse_extensions(session, GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO,
+                                                    GNUTLS_EXT_MANDATORY,
                                                     &data[pos], len);
                        if (ret < 0) {
                                gnutls_assert();
@@ -1595,31 +1602,42 @@ read_server_hello(gnutls_session_t session,
                 */
                DECR_LEN(len, 1);
                pos++;
+               ext_parse_flag = GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO;
+       } else {
+               ext_parse_flag = GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO;
        }
 
        /* Parse extensions in order.
         */
        ret =
-           _gnutls_parse_extensions(session, GNUTLS_EXT_MANDATORY, &data[pos],
-                                    len);
+           _gnutls_parse_extensions(session,
+                                    ext_parse_flag,
+                                    GNUTLS_EXT_MANDATORY,
+                                    &data[pos], len);
        if (ret < 0)
                return gnutls_assert_val(ret);
 
        ret =
-           _gnutls_parse_extensions(session, GNUTLS_EXT_APPLICATION, &data[pos],
-                                    len);
+           _gnutls_parse_extensions(session,
+                                    ext_parse_flag,
+                                    GNUTLS_EXT_APPLICATION,
+                                    &data[pos], len);
        if (ret < 0)
                return gnutls_assert_val(ret);
 
        ret =
-           _gnutls_parse_extensions(session, GNUTLS_EXT_TLS, &data[pos],
-                                    len);
+           _gnutls_parse_extensions(session,
+                                    ext_parse_flag,
+                                    GNUTLS_EXT_TLS,
+                                    &data[pos], len);
        if (ret < 0)
                return gnutls_assert_val(ret);
 
        ret =
-           _gnutls_parse_extensions(session, _GNUTLS_EXT_TLS_POST_CS, &data[pos],
-                                    len);
+           _gnutls_parse_extensions(session,
+                                    ext_parse_flag,
+                                    _GNUTLS_EXT_TLS_POST_CS,
+                                    &data[pos], len);
        if (ret < 0)
                return gnutls_assert_val(ret);
 
@@ -1818,6 +1836,7 @@ static int send_client_hello(gnutls_session_t session, int again)
 
                        ret =
                            _gnutls_gen_extensions(session, &extdata,
+                                                  GNUTLS_EXT_FLAG_CLIENT_HELLO,
                                                   type);
                        if (ret < 0) {
                                gnutls_assert();
@@ -1866,6 +1885,7 @@ static int send_server_hello(gnutls_session_t session, int again)
            session->security_parameters.session_id_size;
        char buf[2 * GNUTLS_MAX_SESSION_ID_SIZE + 1];
        const version_entry_st *vers;
+       gnutls_ext_flags_t ext_parse_flag;
 
        _gnutls_buffer_init(&extdata);
 
@@ -1874,8 +1894,14 @@ static int send_server_hello(gnutls_session_t session, int again)
                if (unlikely(vers == NULL))
                        return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
 
+               if (vers->tls13_sem)
+                       ext_parse_flag = GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO;
+               else
+                       ext_parse_flag = GNUTLS_EXT_FLAG_TLS12_SERVER_HELLO;
+
                ret =
                    _gnutls_gen_extensions(session, &extdata,
+                                          ext_parse_flag,
                                           (session->internals.resumed ==
                                            RESUME_TRUE) ?
                                           GNUTLS_EXT_MANDATORY :