]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
tls: use incomplete API to get full TLS records
authorVictor Julien <vjulien@oisf.net>
Fri, 5 Aug 2022 10:03:37 +0000 (12:03 +0200)
committerVictor Julien <vjulien@oisf.net>
Wed, 21 Sep 2022 04:43:47 +0000 (06:43 +0200)
The TLS record header is parsed in streaming mode still, but once the
record size is known we tell the app-layer API to give us the full
record.

Ticket: #5481

src/app-layer-ssl.c

index f2fbe8038438924c847b5facfbf5020df988a883..b867e85ea106e02dcfa1a97a7213c632a5da418d 100644 (file)
@@ -209,6 +209,26 @@ SslConfig ssl_config;
 
 #define HAS_SPACE(n) ((uint64_t)(input - initial_input) + (uint64_t)(n) <= (uint64_t)(input_len))
 
+struct SSLDecoderResult {
+    int retval;      // nr bytes consumed from input, or < 0 on error
+    uint32_t needed; // more bytes needed
+};
+#define SSL_DECODER_ERROR(e)                                                                       \
+    (struct SSLDecoderResult)                                                                      \
+    {                                                                                              \
+        (e), 0                                                                                     \
+    }
+#define SSL_DECODER_OK(c)                                                                          \
+    (struct SSLDecoderResult)                                                                      \
+    {                                                                                              \
+        (c), 0                                                                                     \
+    }
+#define SSL_DECODER_INCOMPLETE(c, n)                                                               \
+    (struct SSLDecoderResult)                                                                      \
+    {                                                                                              \
+        (c), (n)                                                                                   \
+    }
+
 static inline int SafeMemcpy(void *dst, size_t dst_offset, size_t dst_size,
         const void *src, size_t src_offset, size_t src_size, size_t src_tocopy) WARN_UNUSED;
 
@@ -2316,8 +2336,9 @@ static int SSLv2Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
     }
 }
 
-static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserState *pstate,
-        const uint8_t *input, const uint32_t input_len, const StreamSlice stream_slice)
+static struct SSLDecoderResult SSLv3Decode(uint8_t direction, SSLState *ssl_state,
+        AppLayerParserState *pstate, const uint8_t *input, const uint32_t input_len,
+        const StreamSlice stream_slice)
 {
     uint32_t parsed = 0;
     uint32_t record_len; /* slice of input_len for the current record */
@@ -2328,7 +2349,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
             DEBUG_VALIDATE_BUG_ON(retval > (int)input_len);
             SCLogDebug("SSLv3ParseRecord returned %d", retval);
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER);
-            return -1;
+            return SSL_DECODER_ERROR(-1);
         }
         SCLogDebug("%s input %p record_length %u", (direction == 0) ? "toserver" : "toclient",
                 input, ssl_state->curr_connp->record_length);
@@ -2349,15 +2370,22 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
     SCLogDebug("record length %u processed %u got %u",
             ssl_state->curr_connp->record_length, ssl_state->curr_connp->bytes_processed, record_len);
 
+    if (ssl_state->curr_connp->record_length > input_len - parsed) {
+        uint32_t needed = ssl_state->curr_connp->record_length;
+        SCLogDebug("record len %u input_len %u parsed %u: need %u bytes more data",
+                ssl_state->curr_connp->record_length, input_len, parsed, needed);
+        return SSL_DECODER_INCOMPLETE(parsed, needed);
+    }
+
     if (record_len == 0) {
-        return parsed;
+        return SSL_DECODER_OK(parsed);
     }
 
     /* record_length should never be zero */
     if (ssl_state->curr_connp->record_length == 0) {
         SCLogDebug("SSLv3 Record length is 0");
         SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER);
-        return -1;
+        return SSL_DECODER_ERROR(-1);
     }
     AppLayerFrameNewByPointer(ssl_state->f, &stream_slice, input + parsed,
             ssl_state->curr_connp->record_length, direction, TLS_FRAME_DATA);
@@ -2434,7 +2462,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
                 SSLParserReset(ssl_state);
                 SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
                 SCLogDebug("record len < 4 => %u", ssl_state->curr_connp->record_length);
-                return -1;
+                return SSL_DECODER_ERROR(-1);
             }
 
             int retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed,
@@ -2446,7 +2474,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
                 SSLSetEvent(ssl_state,
                         TLS_DECODER_EVENT_INVALID_SSL_RECORD);
                 SCLogDebug("SSLv3ParseHandshakeProtocol returned %d", retval);
-                return -1;
+                return SSL_DECODER_ERROR(-1);
             }
             SCLogDebug("retval %d", retval);
 
@@ -2468,7 +2496,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
             SCLogDebug("trigger RAW! (post HS)");
             AppLayerParserTriggerRawStreamReassembly(ssl_state->f,
                     direction == 0 ? STREAM_TOSERVER : STREAM_TOCLIENT);
-            return parsed;
+            return SSL_DECODER_OK(parsed);
         }
         case SSLV3_HEARTBEAT_PROTOCOL: {
             AppLayerFrameNewByPointer(ssl_state->f, &stream_slice, input + parsed,
@@ -2477,7 +2505,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
                                                  record_len, direction);
             if (retval < 0) {
                 SCLogDebug("SSLv3ParseHeartbeatProtocol returned %d", retval);
-                return -1;
+                return SSL_DECODER_ERROR(-1);
             }
             break;
         }
@@ -2486,7 +2514,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_RECORD_TYPE);
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
             SCLogDebug("unsupported record type");
-            return -1;
+            return SSL_DECODER_ERROR(-1);
     }
 
     if (HaveEntireRecord(ssl_state->curr_connp, record_len)) {
@@ -2498,7 +2526,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
             /* defensive checks. Something is wrong. */
             SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD);
             SCLogDebug("defensive checks. Something is wrong.");
-            return -1;
+            return SSL_DECODER_ERROR(-1);
         }
 
         SCLogDebug("record complete, trigger RAW");
@@ -2511,15 +2539,15 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserSta
         parsed += diff;
         SSLParserReset(ssl_state);
         ValidateRecordState(ssl_state->curr_connp);
-        return parsed;
+        return SSL_DECODER_OK(parsed);
 
-    /* we still don't have the entire record for the one we are
-       currently parsing */
+        /* we still don't have the entire record for the one we are
+           currently parsing */
     } else {
         parsed += record_len;
         ssl_state->curr_connp->bytes_processed += record_len;
         ValidateRecordState(ssl_state->curr_connp);
-        return parsed;
+        return SSL_DECODER_OK(parsed);
     }
 }
 
@@ -2548,6 +2576,7 @@ static AppLayerResult SSLDecode(Flow *f, uint8_t direction, void *alstate,
     uint32_t counter = 0;
     ssl_state->f = f;
     const uint8_t *input = StreamSliceGetData(&stream_slice);
+    const uint8_t *init_input = input;
     int32_t input_len = (int32_t)StreamSliceGetDataLen(&stream_slice);
 
     if (input == NULL &&
@@ -2626,17 +2655,23 @@ static AppLayerResult SSLDecode(Flow *f, uint8_t direction, void *alstate,
                 SCLogDebug("Continuing parsing TLS record: record_length %u, bytes_processed %u",
                         ssl_state->curr_connp->record_length, ssl_state->curr_connp->bytes_processed);
             }
-            int retval = SSLv3Decode(direction, ssl_state, pstate, input, input_len, stream_slice);
-            if (retval < 0 || retval > input_len) {
-                DEBUG_VALIDATE_BUG_ON(retval > input_len);
+            struct SSLDecoderResult r =
+                    SSLv3Decode(direction, ssl_state, pstate, input, input_len, stream_slice);
+            if (r.retval < 0 || r.retval > input_len) {
+                DEBUG_VALIDATE_BUG_ON(r.retval > input_len);
                 SCLogDebug("Error parsing TLS. Reseting parser "
                         "state.  Let's get outta here");
                 SSLParserReset(ssl_state);
                 return APP_LAYER_ERROR;
+            } else if (r.needed) {
+                input += r.retval;
+                SCLogDebug("returning consumed %" PRIuMAX " needed %u",
+                        (uintmax_t)(input - init_input), r.needed);
+                SCReturnStruct(APP_LAYER_INCOMPLETE(input - init_input, r.needed));
             }
-            input_len -= retval;
-            input += retval;
-            SCLogDebug("TLS decoder consumed %d bytes: %u left", retval, input_len);
+            input_len -= r.retval;
+            input += r.retval;
+            SCLogDebug("TLS decoder consumed %d bytes: %u left", r.retval, input_len);
 
             if (ssl_state->curr_connp->bytes_processed == SSLV3_RECORD_HDR_LEN
                     && ssl_state->curr_connp->record_length == 0) {