]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
TLS app layer: rewrite decoder to handle multiple messages in records
authorPierre Chifflier <pierre.chifflier@ssi.gouv.fr>
Sat, 3 Mar 2012 14:17:14 +0000 (15:17 +0100)
committerVictor Julien <victor@inliniac.net>
Mon, 19 Mar 2012 11:12:25 +0000 (12:12 +0100)
Since we now parse the content of the TLS messages, we need to handle
the case multiple messages are shipped in a single TLS record, and
taking care of the multiple levels of fragmentation (message, record,
and TCP).
Additionally, fix a bug where the parser state was not reset after an
empty record.

src/app-layer-ssl.c
src/app-layer-ssl.h
src/app-layer-tls-handshake.c

index 8c2421de9a2a1a574257f022e7a6934e97a723e3..457d5413be6a5788f2b1fa7190cd353a6d220d02 100644 (file)
@@ -90,10 +90,12 @@ SslConfig ssl_config;
 #define SSLV2_MT_CLIENT_CERTIFICATE   8
 
 #define SSLV3_RECORD_LEN 5
+#define SSLV3_MESSAGE_HDR_LEN 4
 
 static void SSLParserReset(SSLState *ssl_state)
 {
     ssl_state->bytes_processed = 0;
+    ssl_state->message_start = 0;
 }
 
 static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
@@ -114,11 +116,13 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
             switch (ssl_state->bytes_processed) {
                 case 9:
                     ssl_state->bytes_processed++;
+                    parsed++;
                     ssl_state->handshake_client_hello_ssl_version = *(input++) << 8;
                     if (--input_len == 0)
                         break;
                 case 10:
                     ssl_state->bytes_processed++;
+                    parsed++;
                     ssl_state->handshake_client_hello_ssl_version |= *(input++);
                     if (--input_len == 0)
                         break;
@@ -132,6 +136,7 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
             if (rc >= 0) {
                 ssl_state->bytes_processed += rc;
                 input += rc;
+                parsed += rc;
             }
             break;
 
@@ -164,17 +169,14 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
 
             rc = DecodeTLSHandshakeServerCertificate(ssl_state, ssl_state->trec, ssl_state->trec_pos);
             if (rc > 0) {
-                ssl_state->bytes_processed += rc;
-                input += rc;
-            }
-            if (rc == 0) {
-                /* packet is incomplete - do not mark as parsed */
-            }
-            if (rc < 0) {
-                /* error, skip packet */
-                parsed += input_len;
-                ssl_state->bytes_processed += input_len;
-                return parsed;
+                /* do not return normally if the packet was fragmented:
+                 * we would return the size of the *entire* message,
+                 * while we expect only the number of bytes parsed bytes
+                 * from the *current* fragment
+                 */
+                uint32_t diff = input_len - (ssl_state->trec_pos - rc);
+                ssl_state->bytes_processed += diff;
+                return diff;
             }
             break;
         case SSLV3_HS_HELLO_REQUEST:
@@ -188,74 +190,69 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
             break;
     }
 
-    /* looks like we have another record */
-    parsed += (input - initial_input);
-    if ((input_len + ssl_state->bytes_processed) >= ssl_state->record_length + SSLV3_RECORD_LEN) {
-        uint32_t diff = ssl_state->record_length + SSLV3_RECORD_LEN - ssl_state->bytes_processed;
-        parsed += diff;
-        ssl_state->bytes_processed += diff;
-        return parsed;
-
-        /* we still don't have the entire record for the one we are
-         * currently parsing */
-    } else {
+    /* skip the rest of the current message */
+    uint32_t next_msg_offset = ssl_state->message_start + SSLV3_MESSAGE_HDR_LEN + ssl_state->message_length;
+    if (ssl_state->bytes_processed + input_len < next_msg_offset) {
+        /* we don't have enough data */
         parsed += input_len;
         ssl_state->bytes_processed += input_len;
         return parsed;
     }
+    uint32_t diff = next_msg_offset - ssl_state->bytes_processed;
+    parsed += diff;
+    ssl_state->bytes_processed += diff;
+    return parsed;
 }
 
 static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, uint8_t *input,
                                        uint32_t input_len)
 {
     uint8_t *initial_input = input;
+    int retval;
 
     if (input_len == 0) {
         return 0;
     }
 
-    switch (ssl_state->bytes_processed) {
-        case 5:
-            if (input_len >= 4) {
-                ssl_state->handshake_type = *(input++);
-                // XXX we should *not* skip the next 3 bytes, they contain the Message length
-                input += 3;
-                input_len -= 4;
-                ssl_state->bytes_processed += 4;
+    if (ssl_state->message_start == 0) {
+        ssl_state->message_start = SSLV3_RECORD_LEN;
+    }
+
+    switch (ssl_state->bytes_processed - ssl_state->message_start) {
+        case 0:
+            ssl_state->handshake_type = *(input++);
+            ssl_state->bytes_processed++;
+            if (--input_len == 0)
                 break;
-            } else {
-                ssl_state->handshake_type = *(input++);
-                ssl_state->bytes_processed++;
-                if (--input_len == 0)
-                    break;
-            }
-        case 6:
+        case 1:
+            ssl_state->message_length = *(input++) << 16;
             ssl_state->bytes_processed++;
-            input++;
             if (--input_len == 0)
                 break;
-        case 7:
+        case 2:
+            ssl_state->message_length |= *(input++) << 8;
             ssl_state->bytes_processed++;
-            input++;
             if (--input_len == 0)
                 break;
-        case 8:
+        case 3:
+            ssl_state->message_length |= *(input++);
             ssl_state->bytes_processed++;
-            input++;
             if (--input_len == 0)
                 break;
     }
 
-    if (input_len == 0)
-        return (input - initial_input);
-
-    int retval = SSLv3ParseHandshakeType(ssl_state, input, input_len);
-    if (retval == -1) {
-        SCReturnInt(-1);
-    } else {
-        input += retval;
-        return (input - initial_input);
+    retval = SSLv3ParseHandshakeType(ssl_state, input, input_len);
+    if (retval < 0) {
+        SCReturnInt(retval);
+    }
+    uint32_t next_msg_offset = ssl_state->message_start + SSLV3_MESSAGE_HDR_LEN + ssl_state->message_length;
+    if (ssl_state->bytes_processed >= next_msg_offset) {
+        ssl_state->handshake_type = 0;
+        ssl_state->message_length = 0;
+        ssl_state->message_start = next_msg_offset;
     }
+    input += retval;
+    return (input - initial_input);
 }
 
 static int SSLv3ParseRecord(uint8_t direction, SSLState *ssl_state,
@@ -702,6 +699,12 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
                 SCLogDebug("Error parsing SSLv3.x.  Let's get outta here");
                 return -1;
             } else {
+                if ((uint32_t)retval > input_len) {
+                    SCLogDebug("Error parsing SSLv3.x.  Reseting parser "
+                            "state.  Let's get outta here");
+                    SSLParserReset(ssl_state);
+                    return -1;
+                }
                 parsed += retval;
                 input_len -= retval;
                 if (ssl_state->bytes_processed == ssl_state->record_length + SSLV3_RECORD_LEN) {
@@ -803,6 +806,11 @@ static int SSLDecode(uint8_t direction, void *alstate, AppLayerParserState *psta
                     } else {
                         input_len -= retval;
                         input += retval;
+                        if (ssl_state->bytes_processed == SSLV3_RECORD_LEN
+                                && ssl_state->record_length == 0) {
+                            /* empty record */
+                            SSLParserReset(ssl_state);
+                        }
                     }
                 }
 
@@ -830,14 +838,24 @@ static int SSLDecode(uint8_t direction, void *alstate, AppLayerParserState *psta
                                "previously left off");
                     retval = SSLv3Decode(direction, ssl_state, pstate, input,
                                          input_len);
-                    if (retval == -1) {
+                    if (retval < 0) {
                         SCLogDebug("Error parsing SSLv3.x.  Reseting parser "
                                    "state.  Let's get outta here");
                         SSLParserReset(ssl_state);
                         return 0;
                     } else {
+                        if ((uint32_t)retval > input_len) {
+                            SCLogDebug("Error parsing SSLv3.x.  Reseting parser "
+                                       "state.  Let's get outta here");
+                            SSLParserReset(ssl_state);
+                        }
                         input_len -= retval;
                         input += retval;
+                        if (ssl_state->bytes_processed == SSLV3_RECORD_LEN
+                                && ssl_state->record_length == 0) {
+                            /* empty record */
+                            SSLParserReset(ssl_state);
+                        }
                     }
                 }
 
index c8aaed1f9f05b133e79a2fad5d53a4ba9839f3db..7853b3f567d20b7fe6cf0859c84d44a63dc788f7 100644 (file)
@@ -69,6 +69,10 @@ typedef struct SSLState_ {
     /* record length's length for SSLv2 */
     uint32_t record_lengths_length;
 
+    /* offset of the beginning of the current message (including header) */
+    uint32_t message_start;
+    uint32_t message_length;
+
     /* holds some state flags we need */
     uint32_t flags;
 
index 56dc36e75cc21a7057e43ee1a73b9724d3de8b0e..88282ca0a77150048c2af536fb8d0673f5deb52f 100644 (file)
@@ -84,8 +84,7 @@ int DecodeTLSHandshakeServerHello(SSLState *ssl_state, uint8_t *input, uint32_t
 
     SCLogDebug("TLS Handshake Version %.4x Cipher %d Compression %d\n", version, ciphersuite, compressionmethod);
 
-    /* return the message length (TLS record - (handshake type + length)) */
-    return ssl_state->record_length-4;
+    return ssl_state->message_length;
 }
 
 int DecodeTLSHandshakeServerCertificate(SSLState *ssl_state, uint8_t *input, uint32_t input_len)