]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
fix for #725.
authorAnoop Saldanha <anoopsaldanha@gmail.com>
Mon, 21 Jan 2013 06:38:25 +0000 (12:08 +0530)
committerVictor Julien <victor@inliniac.net>
Wed, 6 Mar 2013 09:48:48 +0000 (10:48 +0100)
Update trec_len, trec_pos to 32 bits from 16 bits.
Handle handshakes that are fragmented across records.

src/app-layer-ssl.c
src/app-layer-ssl.h

index eca710872e65c921e9a5e3da93ed11a840f50a2d..141f1360fc39df9f543c2056882c6821415e982c 100644 (file)
@@ -112,7 +112,6 @@ SslConfig ssl_config;
 static void SSLParserReset(SSLState *ssl_state)
 {
     ssl_state->curr_connp->bytes_processed = 0;
-    ssl_state->curr_connp->message_start = 0;
 }
 
 static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
@@ -158,10 +157,16 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
                 /* error, skip packet */
                 parsed += input_len;
                 ssl_state->curr_connp->bytes_processed += input_len;
-                break;
+                return -1;
+            }
+            uint32_t write_len = 0;
+            if ((ssl_state->curr_connp->bytes_processed + input_len) > ssl_state->curr_connp->record_length + (SSLV3_RECORD_HDR_LEN)) {
+                write_len = (ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) - ssl_state->curr_connp->bytes_processed;
+            } else {
+                write_len = input_len;
             }
-            memcpy(ssl_state->curr_connp->trec + ssl_state->curr_connp->trec_pos, initial_input, input_len);
-            ssl_state->curr_connp->trec_pos += input_len;
+            memcpy(ssl_state->curr_connp->trec + ssl_state->curr_connp->trec_pos, initial_input, write_len);
+            ssl_state->curr_connp->trec_pos += write_len;
 
             rc = DecodeTLSHandshakeServerCertificate(ssl_state, ssl_state->curr_connp->trec, ssl_state->curr_connp->trec_pos);
             if (rc > 0) {
@@ -170,10 +175,21 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
                  * while we expect only the number of bytes parsed bytes
                  * from the *current* fragment
                  */
-                uint32_t diff = input_len - (ssl_state->curr_connp->trec_pos - rc);
+                uint32_t diff = write_len - (ssl_state->curr_connp->trec_pos - rc);
                 ssl_state->curr_connp->bytes_processed += diff;
+
+                ssl_state->curr_connp->trec_pos = 0;
+                ssl_state->curr_connp->handshake_type = 0;
+                ssl_state->curr_connp->hs_bytes_processed = 0;
+                ssl_state->curr_connp->message_length = 0;
+
                 return diff;
+            } else {
+                ssl_state->curr_connp->bytes_processed += write_len;
+                parsed += write_len;
+                return parsed;
             }
+
             break;
         case SSLV3_HS_HELLO_REQUEST:
         case SSLV3_HS_CERTIFICATE_REQUEST:
@@ -186,18 +202,29 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input,
             break;
     }
 
-    /* skip the rest of the current message */
-    uint32_t next_msg_offset = ssl_state->curr_connp->message_start + SSLV3_MESSAGE_HDR_LEN + ssl_state->curr_connp->message_length;
-    if (ssl_state->curr_connp->bytes_processed + input_len < next_msg_offset) {
-        /* we don't have enough data */
-        parsed += input_len;
-        ssl_state->curr_connp->bytes_processed += input_len;
+    uint32_t write_len = 0;
+    if ((ssl_state->curr_connp->bytes_processed + input_len) >= ssl_state->curr_connp->record_length + (SSLV3_RECORD_HDR_LEN)) {
+        write_len = (ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) - ssl_state->curr_connp->bytes_processed;
+    } else {
+        write_len = input_len;
+    }
+    if ((ssl_state->curr_connp->trec_pos + write_len) >= ssl_state->curr_connp->message_length) {
+        parsed += ssl_state->curr_connp->message_length - ssl_state->curr_connp->trec_pos;
+
+        ssl_state->curr_connp->bytes_processed += ssl_state->curr_connp->message_length - ssl_state->curr_connp->trec_pos;
+
+        ssl_state->curr_connp->handshake_type = 0;
+        ssl_state->curr_connp->hs_bytes_processed = 0;
+        ssl_state->curr_connp->message_length = 0;
+        ssl_state->curr_connp->trec_pos = 0;
+
+        return parsed;
+    } else {
+        ssl_state->curr_connp->trec_pos += write_len;
+        ssl_state->curr_connp->bytes_processed += write_len;
+        parsed += write_len;
         return parsed;
     }
-    uint32_t diff = next_msg_offset - ssl_state->curr_connp->bytes_processed;
-    parsed += diff;
-    ssl_state->curr_connp->bytes_processed += diff;
-    return parsed;
 }
 
 static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, uint8_t *input,
@@ -210,29 +237,29 @@ static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, uint8_t *input,
         return 0;
     }
 
-    if (ssl_state->curr_connp->message_start == 0) {
-        ssl_state->curr_connp->message_start = SSLV3_RECORD_HDR_LEN;
-    }
-
-    switch (ssl_state->curr_connp->bytes_processed - ssl_state->curr_connp->message_start) {
+    switch (ssl_state->curr_connp->hs_bytes_processed) {
         case 0:
             ssl_state->curr_connp->handshake_type = *(input++);
             ssl_state->curr_connp->bytes_processed++;
+            ssl_state->curr_connp->hs_bytes_processed++;
             if (--input_len == 0)
                 break;
         case 1:
             ssl_state->curr_connp->message_length = *(input++) << 16;
             ssl_state->curr_connp->bytes_processed++;
+            ssl_state->curr_connp->hs_bytes_processed++;
             if (--input_len == 0)
                 break;
         case 2:
             ssl_state->curr_connp->message_length |= *(input++) << 8;
             ssl_state->curr_connp->bytes_processed++;
+            ssl_state->curr_connp->hs_bytes_processed++;
             if (--input_len == 0)
                 break;
         case 3:
             ssl_state->curr_connp->message_length |= *(input++);
             ssl_state->curr_connp->bytes_processed++;
+            ssl_state->curr_connp->hs_bytes_processed++;
             if (--input_len == 0)
                 break;
     }
@@ -243,15 +270,6 @@ static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, uint8_t *input,
     }
     input += retval;
 
-    uint32_t next_msg_offset = ssl_state->curr_connp->message_start + SSLV3_MESSAGE_HDR_LEN + ssl_state->curr_connp->message_length;
-    if (ssl_state->curr_connp->bytes_processed == next_msg_offset) {
-        ssl_state->curr_connp->handshake_type = 0;
-        ssl_state->curr_connp->message_length = 0;
-        ssl_state->curr_connp->message_start = next_msg_offset;
-    } else if (ssl_state->curr_connp->bytes_processed > next_msg_offset) {
-        return -1;
-    }
-
     return (input - initial_input);
 }
 
@@ -648,6 +666,11 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
             if (ssl_state->flags & SSL_AL_FLAG_CHANGE_CIPHER_SPEC)
                 break;
 
+            if (ssl_state->curr_connp->record_length < 4) {
+                SSLParserReset(ssl_state);
+                return -1;
+            }
+
             retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed, input_len);
             if (retval < 0) {
                 AppLayerDecoderEventsSetEvent(ssl_state->f, TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE);
@@ -2251,11 +2274,6 @@ static int SSLParserTest15(void)
     };
     uint32_t buf1_len = sizeof(buf1);
 
-    uint8_t buf2[] = {
-        0x16, 0x03, 0x00, 0x00, 0x00,
-    };
-    uint32_t buf2_len = sizeof(buf2);
-
     TcpSession ssn;
 
     memset(&f, 0, sizeof(f));
@@ -2265,26 +2283,12 @@ static int SSLParserTest15(void)
     StreamTcpInitConfig(TRUE);
 
     int r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf1, buf1_len);
-    if (r != 0) {
-        printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
-        result = 0;
-        goto end;
-    }
-
-    r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf2, buf2_len);
-    if (r != 0) {
+    if (r == 0) {
         printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
         result = 0;
         goto end;
     }
 
-    SSLState *ssl_state = f.alstate;
-    if (ssl_state == NULL) {
-        printf("no tls state: ");
-        result = 0;
-        goto end;
-    }
-
 end:
     StreamTcpFreeConfig(TRUE);
     return result;
@@ -2303,11 +2307,6 @@ static int SSLParserTest16(void)
     };
     uint32_t buf1_len = sizeof(buf1);
 
-    uint8_t buf2[] = {
-        0x16, 0x03, 0x00, 0x00, 0x00,
-    };
-    uint32_t buf2_len = sizeof(buf2);
-
     TcpSession ssn;
 
     memset(&f, 0, sizeof(f));
@@ -2317,26 +2316,12 @@ static int SSLParserTest16(void)
     StreamTcpInitConfig(TRUE);
 
     int r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf1, buf1_len);
-    if (r != 0) {
-        printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
-        result = 0;
-        goto end;
-    }
-
-    r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf2, buf2_len);
-    if (r != 0) {
+    if (r == 0) {
         printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
         result = 0;
         goto end;
     }
 
-    SSLState *ssl_state = f.alstate;
-    if (ssl_state == NULL) {
-        printf("no tls state: ");
-        result = 0;
-        goto end;
-    }
-
 end:
     StreamTcpFreeConfig(TRUE);
     return result;
@@ -2355,11 +2340,6 @@ static int SSLParserTest17(void)
     };
     uint32_t buf1_len = sizeof(buf1);
 
-    uint8_t buf2[] = {
-        0x16, 0x03, 0x00, 0x00, 0x00,
-    };
-    uint32_t buf2_len = sizeof(buf2);
-
     TcpSession ssn;
 
     memset(&f, 0, sizeof(f));
@@ -2369,26 +2349,12 @@ static int SSLParserTest17(void)
     StreamTcpInitConfig(TRUE);
 
     int r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf1, buf1_len);
-    if (r != 0) {
-        printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
-        result = 0;
-        goto end;
-    }
-
-    r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf2, buf2_len);
-    if (r != 0) {
+    if (r == 0) {
         printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
         result = 0;
         goto end;
     }
 
-    SSLState *ssl_state = f.alstate;
-    if (ssl_state == NULL) {
-        printf("no tls state: ");
-        result = 0;
-        goto end;
-    }
-
 end:
     StreamTcpFreeConfig(TRUE);
     return result;
@@ -2511,19 +2477,12 @@ static int SSLParserTest20(void)
     StreamTcpInitConfig(TRUE);
 
     int r = AppLayerParse(NULL, &f, ALPROTO_TLS, STREAM_TOSERVER, buf1, buf1_len);
-    if (r != 0) {
+    if (r == 0) {
         printf("toserver chunk 1 returned %" PRId32 ", expected 0: ", r);
         result = 0;
         goto end;
     }
 
-    SSLState *ssl_state = f.alstate;
-    if (ssl_state == NULL) {
-        printf("no tls state: ");
-        result = 0;
-        goto end;
-    }
-
 end:
     StreamTcpFreeConfig(TRUE);
     return result;
index 9ca516f7b4e3d4138cc7b31a9d1621143b8631da..51efb91efedf64b03d59d4985b896e4b9472581b 100644 (file)
@@ -105,6 +105,8 @@ typedef struct SSLStateConnp_ {
 
     /* the no of bytes processed in the currently parsed record */
     uint16_t bytes_processed;
+    /* the no of bytes processed in the currently parsed handshake */
+    uint16_t hs_bytes_processed;
 
     /* sslv2 client hello session id length */
     uint16_t session_id_length;
@@ -123,8 +125,8 @@ typedef struct SSLStateConnp_ {
     /* buffer for the tls record.
      * We use a malloced buffer, if the record is fragmented */
     uint8_t *trec;
-    uint16_t trec_len;
-    uint16_t trec_pos;
+    uint32_t trec_len;
+    uint32_t trec_pos;
 } SSLStateConnp;
 
 /**