]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
ssl: improve error checking
authorVictor Julien <victor@inliniac.net>
Fri, 3 Apr 2020 13:34:41 +0000 (15:34 +0200)
committerVictor Julien <victor@inliniac.net>
Tue, 28 Apr 2020 12:07:54 +0000 (14:07 +0200)
src/app-layer-ssl.c

index 9b7cbfc01d71de187e18fdf75cad2984f4fc8909..fee73c14628b788eac74702ff0a1350894b91a67 100644 (file)
@@ -1591,7 +1591,6 @@ static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, const uint8_t *input
                                        uint32_t input_len, uint8_t direction)
 {
     const uint8_t *initial_input = input;
-    int retval;
 
     if (input_len == 0 || ssl_state->curr_connp->bytes_processed ==
             (ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN)) {
@@ -1641,11 +1640,10 @@ static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, const uint8_t *input
             /* fall through */
     }
 
-    retval = SSLv3ParseHandshakeType(ssl_state, input, input_len, direction);
-    if (retval < 0) {
+    int retval = SSLv3ParseHandshakeType(ssl_state, input, input_len, direction);
+    if (retval < 0 || (uint32_t)retval > input_len) {
         return retval;
     }
-
     input += retval;
 
     return (input - initial_input);
@@ -1962,13 +1960,12 @@ static int SSLv2Decode(uint8_t direction, SSLState *ssl_state,
     if (ssl_state->curr_connp->bytes_processed <
             (ssl_state->curr_connp->record_lengths_length + 1)) {
         retval = SSLv2ParseRecord(direction, ssl_state, input, input_len);
-        if (retval == -1) {
+        if (retval < 0 || (uint32_t)retval > input_len) {
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSLV2_HEADER);
             return -1;
-        } else {
-            input += retval;
-            input_len -= retval;
         }
+        input += retval;
+        input_len -= retval;
     }
 
     if (input_len == 0) {
@@ -2225,7 +2222,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
 
     if (ssl_state->curr_connp->bytes_processed < SSLV3_RECORD_HDR_LEN) {
         int retval = SSLv3ParseRecord(direction, ssl_state, input, input_len);
-        if (retval < 0) {
+        if (retval < 0 || (uint32_t)retval > input_len) {
             SCLogDebug("SSLv3ParseRecord returned %d", retval);
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER);
             return -1;
@@ -2316,41 +2313,30 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
 
             int retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
                                                      input_len, direction);
-            if (retval < 0) {
+            if (retval < 0 || (uint32_t)retval > input_len) {
                 SSLSetEvent(ssl_state,
                         TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE);
                 SSLSetEvent(ssl_state,
                         TLS_DECODER_EVENT_INVALID_SSL_RECORD);
                 SCLogDebug("SSLv3ParseHandshakeProtocol returned %d", retval);
                 return -1;
-            } else {
-                if ((uint32_t)retval > input_len) {
-                    SCLogDebug("Error parsing SSLv3.x. Reseting parser "
-                               "state. Let's get outta here");
-                    SSLParserReset(ssl_state);
-                    SSLSetEvent(ssl_state,
-                            TLS_DECODER_EVENT_INVALID_SSL_RECORD);
-                    return -1;
-                }
-
-                parsed += retval;
-                input_len -= retval;
-                (void)input_len; /* for scan-build */
+            }
 
-                if (ssl_state->curr_connp->bytes_processed ==
-                        ssl_state->curr_connp->record_length +
-                        SSLV3_RECORD_HDR_LEN) {
-                    SCLogDebug("record ready");
-                    SSLParserReset(ssl_state);
-                }
+            parsed += retval;
+            input_len -= retval;
+            (void)input_len; /* for scan-build */
 
-                SCLogDebug("trigger RAW! (post HS)");
-                AppLayerParserTriggerRawStreamReassembly(ssl_state->f,
-                        direction == 0 ? STREAM_TOSERVER : STREAM_TOCLIENT);
-                return parsed;
+            if (ssl_state->curr_connp->bytes_processed ==
+                    ssl_state->curr_connp->record_length +
+                    SSLV3_RECORD_HDR_LEN) {
+                SCLogDebug("record ready");
+                SSLParserReset(ssl_state);
             }
 
-            break;
+            SCLogDebug("trigger RAW! (post HS)");
+            AppLayerParserTriggerRawStreamReassembly(ssl_state->f,
+                    direction == 0 ? STREAM_TOSERVER : STREAM_TOCLIENT);
+            return parsed;
         }
         case SSLV3_HEARTBEAT_PROTOCOL: {
             int retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed,
@@ -2481,22 +2467,17 @@ static AppLayerResult SSLDecode(Flow *f, uint8_t direction, void *alstate, AppLa
             }
             int retval = SSLv2Decode(direction, ssl_state, pstate, input,
                     input_len);
-            if (retval < 0) {
+            if (retval < 0 || retval > input_len) {
                 SCLogDebug("Error parsing SSLv2. Reseting parser "
                         "state. Let's get outta here");
                 SSLParserReset(ssl_state);
                 SSLSetEvent(ssl_state,
                         TLS_DECODER_EVENT_INVALID_SSL_RECORD);
                 return APP_LAYER_OK;
-            } else if (retval > input_len) {
-                SCLogDebug("Error parsing SSLv2. Reseting parser "
-                        "state.  Let's get outta here");
-                SSLParserReset(ssl_state);
-            } else {
-                input_len -= retval;
-                input += retval;
-                SCLogDebug("SSLv2 decoder consumed %d bytes: %u left", retval, input_len);
             }
+            input_len -= retval;
+            input += retval;
+            SCLogDebug("SSLv2 decoder consumed %d bytes: %u left", retval, input_len);
         } else {
             if (ssl_state->curr_connp->bytes_processed == 0) {
                 SCLogDebug("New TLS record");
@@ -2505,26 +2486,21 @@ static AppLayerResult SSLDecode(Flow *f, uint8_t direction, void *alstate, AppLa
             }
             int retval = SSLv3Decode(direction, ssl_state, pstate, input,
                     input_len);
-            if (retval < 0) {
+            if (retval < 0 || retval > input_len) {
                 SCLogDebug("Error parsing TLS. Reseting parser "
                         "state.  Let's get outta here");
                 SSLParserReset(ssl_state);
                 return APP_LAYER_ERROR;
-            } else if (retval > input_len) {
-                SCLogDebug("Error parsing TLS. Reseting parser "
-                        "state.  Let's get outta here");
+            }
+            input_len -= retval;
+            input += retval;
+            SCLogDebug("TLS decoder consumed %d bytes: %u left", retval, input_len);
+
+            if (ssl_state->curr_connp->bytes_processed == SSLV3_RECORD_HDR_LEN
+                    && ssl_state->curr_connp->record_length == 0) {
+                SCLogDebug("TLS empty record");
+                /* empty record */
                 SSLParserReset(ssl_state);
-            } else {
-                input_len -= retval;
-                input += retval;
-                SCLogDebug("TLS decoder consumed %d bytes: %u left", retval, input_len);
-
-                if (ssl_state->curr_connp->bytes_processed == SSLV3_RECORD_HDR_LEN
-                        && ssl_state->curr_connp->record_length == 0) {
-                    SCLogDebug("TLS empty record");
-                    /* empty record */
-                    SSLParserReset(ssl_state);
-                }
             }
         }
         counter++;