]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
ssl: handshake parsing code cleanup
authorVictor Julien <victor@inliniac.net>
Fri, 21 Feb 2020 21:52:24 +0000 (22:52 +0100)
committerVictor Julien <victor@inliniac.net>
Tue, 28 Apr 2020 12:07:54 +0000 (14:07 +0200)
src/app-layer-ssl.c

index f0b474e5ac5a359f30fcbad6d5950fb91ab5990c..1ffdad5e251f440c61ea8e8e8bdd64513da3a143 100644 (file)
@@ -1406,11 +1406,82 @@ error:
     return -1;
 }
 
+static inline bool
+HaveEntireRecord(const SSLStateConnp *curr_connp, const uint32_t input_len)
+{
+    return (curr_connp->bytes_processed + input_len) >=
+                      (curr_connp->record_length + SSLV3_RECORD_HDR_LEN);
+}
+
+static inline bool
+RecordAlreadyProcessed(const SSLStateConnp *curr_connp)
+{
+    return ((curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
+            curr_connp->bytes_processed);
+}
+
+static inline int SSLv3ParseHandshakeTypeCertificate(SSLState *ssl_state,
+        const uint8_t * const initial_input,
+        const uint32_t input_len)
+{
+    if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) {
+        /* error, skip packet */
+        ssl_state->curr_connp->bytes_processed += input_len;
+        return -1;
+    }
+
+    uint32_t write_len = 0;
+    if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
+        if (RecordAlreadyProcessed(ssl_state->curr_connp)) {
+            SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
+            return -1;
+        }
+        write_len = (ssl_state->curr_connp->record_length +
+                SSLV3_RECORD_HDR_LEN) - ssl_state->curr_connp->bytes_processed;
+    } else {
+        write_len = input_len;
+    }
+
+    if (SafeMemcpy(ssl_state->curr_connp->trec,
+                ssl_state->curr_connp->trec_pos,
+                ssl_state->curr_connp->trec_len,
+                initial_input, 0, input_len, write_len) != 0) {
+        return -1;
+    }
+    ssl_state->curr_connp->trec_pos += write_len;
+
+    int rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
+            ssl_state->curr_connp->trec_pos);
+    if (rc > 0) {
+        /* do not return normally if the packet was fragmented:
+           we would return the size of the _entire_ message,
+           while we expect only the number of bytes parsed bytes
+           from the _current_ fragment */
+        if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
+            SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
+            return -1;
+        }
+
+        uint32_t diff = write_len -
+            (ssl_state->curr_connp->trec_pos - rc);
+        ssl_state->curr_connp->bytes_processed += diff;
+
+        ssl_state->curr_connp->trec_pos = 0;
+        ssl_state->curr_connp->handshake_type = 0;
+        ssl_state->curr_connp->hs_bytes_processed = 0;
+        ssl_state->curr_connp->message_length = 0;
+
+        return diff;
+    } else {
+        ssl_state->curr_connp->bytes_processed += write_len;
+        return write_len;
+    }
+}
+
 static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
                                    uint32_t input_len, uint8_t direction)
 {
     const uint8_t *initial_input = input;
-    uint32_t parsed = 0;
     int rc;
 
     if (input_len == 0) {
@@ -1425,7 +1496,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
             if (input_len >= ssl_state->curr_connp->message_length &&
                       input_len >= 40) {
                 rc = TLSDecodeHandshakeHello(ssl_state, input, input_len);
-
                 if (rc < 0)
                     return rc;
             }
@@ -1440,7 +1510,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
                     input_len >= 40) {
                 rc = TLSDecodeHandshakeHello(ssl_state, input,
                                              ssl_state->curr_connp->message_length);
-
                 if (rc < 0)
                     return rc;
             }
@@ -1462,72 +1531,10 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
                            "direction!");
                 break;
             }
-
-            if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) {
-                /* error, skip packet */
-                parsed += input_len;
-                (void)parsed; /* for scan-build */
-                ssl_state->curr_connp->bytes_processed += input_len;
-                return -1;
-            }
-
-            uint32_t write_len = 0;
-            if ((ssl_state->curr_connp->bytes_processed + input_len) >
-                    ssl_state->curr_connp->record_length +
-                    (SSLV3_RECORD_HDR_LEN)) {
-                if ((ssl_state->curr_connp->record_length +
-                        SSLV3_RECORD_HDR_LEN) <
-                        ssl_state->curr_connp->bytes_processed) {
-                    SSLSetEvent(ssl_state,
-                            TLS_DECODER_EVENT_INVALID_SSL_RECORD);
-                    return -1;
-                }
-                write_len = (ssl_state->curr_connp->record_length +
-                        SSLV3_RECORD_HDR_LEN) -
-                        ssl_state->curr_connp->bytes_processed;
-            } else {
-                write_len = input_len;
-            }
-
-            if (SafeMemcpy(ssl_state->curr_connp->trec,
-                        ssl_state->curr_connp->trec_pos,
-                        ssl_state->curr_connp->trec_len,
-                        initial_input, 0, input_len, write_len) != 0) {
-                return -1;
-            }
-            ssl_state->curr_connp->trec_pos += write_len;
-
-            rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
-                                        ssl_state->curr_connp->trec_pos);
-
-            if (rc > 0) {
-                /* do not return normally if the packet was fragmented:
-                   we would return the size of the _entire_ message,
-                   while we expect only the number of bytes parsed bytes
-                   from the _current_ fragment */
-                if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
-                    SSLSetEvent(ssl_state,
-                            TLS_DECODER_EVENT_INVALID_SSL_RECORD);
-                    return -1;
-                }
-
-                uint32_t diff = write_len -
-                        (ssl_state->curr_connp->trec_pos - rc);
-                ssl_state->curr_connp->bytes_processed += diff;
-
-                ssl_state->curr_connp->trec_pos = 0;
-                ssl_state->curr_connp->handshake_type = 0;
-                ssl_state->curr_connp->hs_bytes_processed = 0;
-                ssl_state->curr_connp->message_length = 0;
-
-                return diff;
-            } else {
-                ssl_state->curr_connp->bytes_processed += write_len;
-                parsed += write_len;
-                return parsed;
-            }
-
+            return SSLv3ParseHandshakeTypeCertificate(ssl_state,
+                    initial_input, input_len);
             break;
+
         case SSLV3_HS_HELLO_REQUEST:
         case SSLV3_HS_CERTIFICATE_REQUEST:
         case SSLV3_HS_CERTIFICATE_VERIFY:
@@ -1546,10 +1553,8 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
     ssl_state->flags |= ssl_state->current_flags;
 
     uint32_t write_len = 0;
-    if ((ssl_state->curr_connp->bytes_processed + input_len) >=
-            ssl_state->curr_connp->record_length + (SSLV3_RECORD_HDR_LEN)) {
-        if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
-                ssl_state->curr_connp->bytes_processed) {
+    if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
+        if (RecordAlreadyProcessed(ssl_state->curr_connp)) {
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
             return -1;
         }
@@ -1566,24 +1571,18 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
             return -1;
         }
-        parsed += ssl_state->curr_connp->message_length -
+        const uint32_t parsed = ssl_state->curr_connp->message_length -
                 ssl_state->curr_connp->trec_pos;
-
-        ssl_state->curr_connp->bytes_processed +=
-                ssl_state->curr_connp->message_length -
-                ssl_state->curr_connp->trec_pos;
-
+        ssl_state->curr_connp->bytes_processed += parsed;
         ssl_state->curr_connp->handshake_type = 0;
         ssl_state->curr_connp->hs_bytes_processed = 0;
         ssl_state->curr_connp->message_length = 0;
         ssl_state->curr_connp->trec_pos = 0;
-
         return parsed;
     } else {
         ssl_state->curr_connp->trec_pos += write_len;
         ssl_state->curr_connp->bytes_processed += write_len;
-        parsed += write_len;
-        return parsed;
+        return write_len;
     }
 }
 
@@ -2221,18 +2220,16 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
                        AppLayerParserState *pstate, const uint8_t *input,
                        uint32_t input_len)
 {
-    int retval = 0;
     uint32_t parsed = 0;
 
     if (ssl_state->curr_connp->bytes_processed < SSLV3_RECORD_HDR_LEN) {
-        retval = SSLv3ParseRecord(direction, ssl_state, input, input_len);
+        int retval = SSLv3ParseRecord(direction, ssl_state, input, input_len);
         if (retval < 0) {
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER);
             return -1;
-        } else {
-            parsed += retval;
-            input_len -= retval;
         }
+        parsed += retval;
+        input_len -= retval;
     }
 
     if (input_len == 0) {
@@ -2296,7 +2293,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
 
             break;
 
-        case SSLV3_HANDSHAKE_PROTOCOL:
+        case SSLV3_HANDSHAKE_PROTOCOL: {
             if (ssl_state->flags & SSL_AL_FLAG_CHANGE_CIPHER_SPEC) {
                 /* In TLSv1.3, ChangeCipherSpec is only used for middlebox
                    compability (rfc8446, appendix D.4). */
@@ -2314,8 +2311,8 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
                 return -1;
             }
 
-            retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
-                                                 input_len, direction);
+            int retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
+                                                     input_len, direction);
             if (retval < 0) {
                 SSLSetEvent(ssl_state,
                         TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE);
@@ -2349,15 +2346,15 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
             }
 
             break;
-
-        case SSLV3_HEARTBEAT_PROTOCOL:
-            retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed,
+        }
+        case SSLV3_HEARTBEAT_PROTOCOL: {
+            int retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed,
                                                  input_len, direction);
             if (retval < 0)
                 return -1;
 
             break;
-
+        }
         default:
             /* \todo fix the event from invalid rule to unknown rule */
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_RECORD_TYPE);
@@ -2365,8 +2362,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
             return -1;
     }
 
-    if (input_len + ssl_state->curr_connp->bytes_processed >=
-            ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) {
+    if (HaveEntireRecord(ssl_state->curr_connp, input_len)) {
         if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) <
                 ssl_state->curr_connp->bytes_processed) {
             /* defensive checks. Something is wrong. */
@@ -2392,7 +2388,6 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state,
         ssl_state->curr_connp->bytes_processed += input_len;
         return parsed;
     }
-
 }
 
 /**