]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
app-layer-ssl: split function into multiple smaller functions
authorMats Klepsland <mats.klepsland@gmail.com>
Mon, 27 Nov 2017 11:23:57 +0000 (12:23 +0100)
committerVictor Julien <victor@inliniac.net>
Tue, 20 Mar 2018 15:27:22 +0000 (16:27 +0100)
Split 'TLSDecodeHandshakeHello' into smaller functions to make
it easier to read the code when the function grows in size.

src/app-layer-ssl.c

index 8ab41cd042eb25a02c10c4cf08d36dbb02fd0d79..d4c36b272e53605d6da71c297b762cf3bb412e03 100644 (file)
@@ -134,7 +134,7 @@ SslConfig ssl_config;
 
 #define SSL_RECORD_MINIMUM_LENGTH       6
 
-#define HAS_SPACE(n) ((uint32_t)((input) + (n) - (initial_input)) > (uint32_t)(input_len)) ?  0 : 1
+#define HAS_SPACE(n) ((uint32_t)((*input) + (n) - (initial_input)) > (uint32_t)(input_len)) ?  0 : 1
 
 static void SSLParserReset(SSLState *ssl_state)
 {
@@ -250,55 +250,103 @@ static void SSLSetTxDetectFlags(void *vtx, uint8_t dir, uint64_t flags)
     }
 }
 
-static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
-                                   uint32_t input_len)
+static inline int TLSDecodeHSHelloVersion(SSLState *ssl_state, uint8_t **input,
+                                          const uint32_t input_len,
+                                          const uint8_t *initial_input)
 {
-    uint8_t *initial_input = input;
+    /* Skip version */
+    *input += SSLV3_CLIENT_HELLO_VERSION_LEN;
 
-    /* only parse the message if it is complete */
-    if (input_len < ssl_state->curr_connp->message_length || input_len < 40)
-        return 0;
+    return 0;
+}
 
-    /* skip version */
-    input += SSLV3_CLIENT_HELLO_VERSION_LEN;
+static inline int TLSDecodeHSHelloRandom(SSLState *ssl_state, uint8_t **input,
+                                         const uint32_t input_len,
+                                         const uint8_t *initial_input)
+{
+    /* Skip random */
+    *input += SSLV3_CLIENT_HELLO_RANDOM_LEN;
 
-    /* skip random */
-    input += SSLV3_CLIENT_HELLO_RANDOM_LEN;
+    return 0;
+}
 
-    if (!(HAS_SPACE(1)))
-        goto invalid_length;
+static inline int TLSDecodeHSHelloSessionID(SSLState *ssl_state,
+                                            uint8_t **input,
+                                            const uint32_t input_len,
+                                            const uint8_t *initial_input)
+{
+    if (!(HAS_SPACE(1))) {
+        SCLogDebug("TLS handshake invalid length");
+        SSLSetEvent(ssl_state,
+                    TLS_DECODER_EVENT_HANDSHAKE_INVALID_LENGTH);
+        return -1;
+    }
+
+    uint8_t session_id_length = **input;
+    *input += 1;
 
-    /* skip session id */
-    uint8_t session_id_length = *(input++);
     if (session_id_length != 0) {
         ssl_state->flags |= SSL_AL_FLAG_SSL_CLIENT_SESSION_ID;
     }
 
-    input += session_id_length;
+    *input += session_id_length;
 
-    if (!(HAS_SPACE(2)))
-        goto invalid_length;
+    return 0;
+}
 
-    /* skip cipher suites */
-    uint16_t cipher_suites_length = input[0] << 8 | input[1];
-    input += 2;
+static inline int TLSDecodeHSHelloCipherSuites(SSLState *ssl_state,
+                                               uint8_t **input,
+                                               const uint32_t input_len,
+                                               const uint8_t *initial_input)
+{
+    if (!(HAS_SPACE(2))) {
+        SCLogDebug("TLS handshake invalid length");
+        SSLSetEvent(ssl_state,
+                    TLS_DECODER_EVENT_HANDSHAKE_INVALID_LENGTH);
+        return -1;
+    }
 
-    input += cipher_suites_length;
+    uint16_t cipher_suites_length = **input << 8 | *(*input + 1);
+    *input += 2;
 
-    if (!(HAS_SPACE(1)))
-        goto invalid_length;
+    /* Skip cipher suites */
+    *input += cipher_suites_length;
+
+    return 0;
+}
+
+static inline int TLSDecodeHSHelloCompressionMethods(SSLState *ssl_state,
+                                                     uint8_t **input,
+                                                     const uint32_t input_len,
+                                                     const uint8_t *initial_input)
+{
+    if (!(HAS_SPACE(1))) {
+        SCLogDebug("TLS handshake invalid length");
+        SSLSetEvent(ssl_state,
+                    TLS_DECODER_EVENT_HANDSHAKE_INVALID_LENGTH);
+        return -1;
+    }
 
-    /* skip compression methods */
-    uint8_t compression_methods_length = *(input++);
+    /* Skip compression methods */
+    uint8_t compression_methods_length = **input;
+    *input += 1;
 
-    input += compression_methods_length;
+    *input += compression_methods_length;
 
-    /* extensions are optional (RFC5246 section 7.4.1.2) */
+    return 0;
+}
+
+static inline int TLSDecodeHSHelloExtensions(SSLState *ssl_state,
+                                             uint8_t **input,
+                                             const uint32_t input_len,
+                                             const uint8_t *initial_input)
+{
+    /* Extensions are optional (RFC5246 section 7.4.1.2) */
     if (!(HAS_SPACE(2)))
         goto end;
 
-    uint16_t extensions_len = input[0] << 8 | input[1];
-    input += 2;
+    uint16_t extensions_len = **input << 8 | *(*input + 1);
+    *input += 2;
 
     if (!(HAS_SPACE(extensions_len)))
         goto invalid_length;
@@ -309,20 +357,20 @@ static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
         if (!(HAS_SPACE(2)))
             goto invalid_length;
 
-        uint16_t ext_type = input[0] << 8 | input[1];
-        input += 2;
+        uint16_t ext_type = **input << 8 | *(*input + 1);
+        *input += 2;
 
         if (!(HAS_SPACE(2)))
             goto invalid_length;
 
-        uint16_t ext_len = input[0] << 8 | input[1];
-        input += 2;
+        uint16_t ext_len = **input << 8 | *(*input + 1);
+        *input += 2;
 
         switch (ext_type) {
             case SSL_EXTENSION_SNI:
             {
-                /* there must not be more than one extension of the same
-                   type (RFC5246 section 7.4.1.4) */
+                /* There must not be more than one extension of the same
+                   type (RFC5246 section 7.4.1.4). */
                 if (ssl_state->curr_connp->sni) {
                     SCLogDebug("Multiple SNI extensions");
                     SSLSetEvent(ssl_state,
@@ -330,16 +378,17 @@ static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
                     return -1;
                 }
 
-                /* skip sni_list_length */
-                input += 2;
+                /* Skip sni_list_length */
+                *input += 2;
 
                 if (!(HAS_SPACE(1)))
                     goto invalid_length;
 
-                uint8_t sni_type = *(input++);
+                uint8_t sni_type = **input;
+                *input += 1;
 
-                /* currently the only type allowed is host_name
-                   (RFC6066 section 3) */
+                /* Currently the only type allowed is host_name
+                   (RFC6066 section 3). */
                 if (sni_type != SSL_SNI_TYPE_HOST_NAME) {
                     SCLogDebug("Unknown SNI type");
                     SSLSetEvent(ssl_state,
@@ -350,15 +399,15 @@ static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
                 if (!(HAS_SPACE(2)))
                     goto invalid_length;
 
-                uint16_t sni_len = input[0] << 8 | input[1];
-                input += 2;
+                uint16_t sni_len = **input << 8 | *(*input + 1);
+                *input += 2;
 
                 if (!(HAS_SPACE(sni_len)))
                     goto invalid_length;
 
                 /* host_name contains the fully qualified domain name,
                    and should therefore be limited by the maximum domain
-                   name length */
+                   name length. */
                 if (sni_len > 255) {
                     SCLogDebug("SNI length >255");
                     SSLSetEvent(ssl_state,
@@ -370,17 +419,18 @@ static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
                 ssl_state->curr_connp->sni = SCMalloc(sni_strlen);
 
                 if (unlikely(ssl_state->curr_connp->sni == NULL))
-                    goto end;
+                    return -1;
 
-                memcpy(ssl_state->curr_connp->sni, input, sni_strlen - 1);
+                memcpy(ssl_state->curr_connp->sni, *input, sni_strlen - 1);
                 ssl_state->curr_connp->sni[sni_strlen-1] = 0;
 
-                input += sni_len;
+                *input += sni_len;
                 break;
             }
+
             default:
             {
-                input += ext_len;
+                *input += ext_len;
                 break;
             }
         }
@@ -393,7 +443,48 @@ end:
 invalid_length:
     SCLogDebug("TLS handshake invalid length");
     SSLSetEvent(ssl_state,
-            TLS_DECODER_EVENT_HANDSHAKE_INVALID_LENGTH);
+                TLS_DECODER_EVENT_HANDSHAKE_INVALID_LENGTH);
+    return -1;
+}
+
+static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
+                                   uint32_t input_len)
+{
+    int rc;
+    uint8_t *initial_input = input;
+
+    /* Only parse the message if it is complete */
+    if (input_len < ssl_state->curr_connp->message_length || input_len < 40)
+        goto end;
+
+    rc = TLSDecodeHSHelloVersion(ssl_state, &input, input_len, initial_input);
+    if (rc != 0)
+        goto end;
+
+    rc = TLSDecodeHSHelloRandom(ssl_state, &input, input_len, initial_input);
+    if (rc != 0)
+        goto end;
+
+    rc = TLSDecodeHSHelloSessionID(ssl_state, &input, input_len, initial_input);
+    if (rc != 0)
+        goto end;
+
+    rc = TLSDecodeHSHelloCipherSuites(ssl_state, &input, input_len,
+                                      initial_input);
+    if (rc != 0)
+        goto end;
+
+    rc = TLSDecodeHSHelloCompressionMethods(ssl_state, &input, input_len,
+                                            initial_input);
+    if (rc != 0)
+        goto end;
+
+    rc = TLSDecodeHSHelloExtensions(ssl_state, &input, input_len,
+                                    initial_input);
+    if (rc != 0)
+        goto end;
+
+end:
     return 0;
 }