]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
ssl: cert parsing hardening
authorJeff Lucovsky <jeff@lucovsky.org>
Thu, 10 Mar 2022 14:09:57 +0000 (15:09 +0100)
committerVictor Julien <vjulien@oisf.net>
Thu, 21 Apr 2022 05:38:50 +0000 (07:38 +0200)
Limit initial memory allocated for certificates. General parsing
hardening and checking improvements.

Based on commits:
862e84877ff262cd4b8c4b191a8710f94f63fcf7
3ed188e0bcb6f4ae5b6c5eafdd75ce1f8e3d7246

Bug: #5193.

src/app-layer-ssl.c

index 2e8d4f39c3765259b502bcaefc2e685a0551e19e..a774b470894f358aa60c14f3a16311e132fbce15 100644 (file)
@@ -1433,10 +1433,176 @@ end:
     return 0;
 }
 
+/** \internal
+ *  \brief Get Certificates len so we can know now much (more) we need to buffer
+ *  If we already have a few bytes queued up in 'trec' we use those or a mix of
+ *  those with the input.
+ */
+static uint32_t GetCertsLen(SSLStateConnp *curr_connp, const uint8_t *input,
+        const uint32_t input_len)
+{
+    if (curr_connp->trec != NULL && curr_connp->trec_pos > 0) {
+        if (curr_connp->trec_pos >= 3) {
+            const uint8_t * const ptr = curr_connp->trec;
+            uint32_t len = (*ptr << 16 | *(ptr + 1) << 8 | *(ptr + 2)) + 3;
+            SCLogDebug("length %u (trec)", len);
+            return len;
+        } else if (input_len + curr_connp->trec_pos >= 3) {
+            uint8_t buf[3] = { 0, 0, 0, }; // init to 0 to make scan-build happy
+            uint32_t i = 0;
+            for (uint32_t x = 0; x < curr_connp->trec_pos && i < 3;  x++) {
+                buf[i++] = curr_connp->trec[x];
+            }
+            for (uint32_t x = 0; x < input_len && i < 3;  x++) {
+                buf[i++] = input[x];
+            }
+            uint32_t len = (buf[0] << 16 | buf[1] << 8 | buf[2]) + 3;
+            SCLogDebug("length %u (part trec, part input)", len);
+            return len;
+        }
+        return 0;
+    } else if (input_len >= 3) {
+        uint32_t len = (*input << 16 | *(input + 1) << 8 | *(input + 2)) + 3;
+        SCLogDebug("length %u (input)", len);
+        return len;
+    } else {
+        SCLogDebug("length unknown (input len %u)", input_len);
+        return 0;
+    }
+}
+
+// For certificates whose size is bigger than this,
+// we do not allocate all the required memory straight away,
+// to avoid DOS by RAM exhaustion, but we will allocate
+// this memory once a consequent part of the certificate has been seen.
+#define SSL_CERT_MAX_FIRST_ALLOC 65536 // 0x10000
+
+/** \internal
+ *  \brief setup or grow the `trec` space in the connp
+ */
+static int EnsureRecordSpace(SSLStateConnp *curr_connp, const uint8_t * const input,
+        const uint32_t input_len)
+{
+    ValidateTrecBuffer(curr_connp);
+
+    uint32_t certs_len = GetCertsLen(curr_connp, input, input_len);
+    if (certs_len == 0) {
+        SCLogDebug("cert_len unknown still, create small buffer to start");
+        certs_len = 256;
+    }
+    // Limit in a first time allocation for very large certificates
+    if (certs_len > SSL_CERT_MAX_FIRST_ALLOC && certs_len > curr_connp->trec_pos + input_len) {
+        certs_len = SSL_CERT_MAX_FIRST_ALLOC;
+    }
+
+    if (curr_connp->trec == NULL) {
+        curr_connp->trec_len = certs_len;
+        curr_connp->trec = SCMalloc(curr_connp->trec_len);
+        if (unlikely(curr_connp->trec == NULL))
+            goto error;
+    }
+
+    if (certs_len > curr_connp->trec_len) {
+        curr_connp->trec_len = certs_len;
+        void *ptmp = SCRealloc(curr_connp->trec, curr_connp->trec_len);
+        if (unlikely(ptmp == NULL)) {
+            SCFree(curr_connp->trec);
+            curr_connp->trec = NULL;
+            goto error;
+        }
+        curr_connp->trec = ptmp;
+    }
+    ValidateTrecBuffer(curr_connp);
+    return 0;
+error:
+    curr_connp->trec_len = 0;
+    curr_connp->trec_pos = 0;
+    ValidateTrecBuffer(curr_connp);
+    return -1;
+}
+
+static inline int SSLv3ParseHandshakeTypeCertificate(SSLState *ssl_state,
+        const uint8_t * const initial_input,
+        const uint32_t input_len)
+{
+    ValidateTrecBuffer(ssl_state->curr_connp);
+    const uint32_t certs_len = GetCertsLen(ssl_state->curr_connp, initial_input, input_len);
+    SCLogDebug("certs_len %u", certs_len);
+
+    if (EnsureRecordSpace(ssl_state->curr_connp, initial_input, input_len) < 0) {
+        /* error, skip input data */
+        ssl_state->curr_connp->bytes_processed += input_len;
+        return -1;
+    }
+    ValidateTrecBuffer(ssl_state->curr_connp);
+
+    uint32_t write_len = 0;
+    if (certs_len != 0 && ssl_state->curr_connp->trec_pos + input_len >= certs_len) {
+        write_len = certs_len - ssl_state->curr_connp->trec_pos;
+        /* get data left in this frag. The length field may indicate more
+         * data than just in this record. */
+        uint32_t cur_frag_left = ssl_state->curr_connp->record_length +
+                SSLV3_RECORD_HDR_LEN - ssl_state->curr_connp->bytes_processed;
+        SCLogDebug("write_len %u cur_frag_left %u", write_len, cur_frag_left);
+        write_len = MIN(write_len, cur_frag_left);
+    } else {
+        uint32_t cur_frag_left = (ssl_state->curr_connp->record_length +
+                SSLV3_RECORD_HDR_LEN - ssl_state->curr_connp->bytes_processed);
+        SCLogDebug("cur_frag_left %u", cur_frag_left);
+        write_len = MIN(input_len, cur_frag_left);
+        SCLogDebug("write_len %u", write_len);
+    }
+    if (write_len == 0) {
+        /* no (new) data, so we're done */
+        ValidateTrecBuffer(ssl_state->curr_connp);
+        return 0;
+    }
+    BUG_ON(write_len > input_len);
+
+    if (SafeMemcpy(ssl_state->curr_connp->trec,
+                ssl_state->curr_connp->trec_pos,
+                ssl_state->curr_connp->trec_len,
+                initial_input, 0, input_len, write_len) != 0) {
+        return -1;
+    }
+    ssl_state->curr_connp->trec_pos += write_len;
+    SCLogDebug("ssl_state->curr_connp->trec_pos %u", ssl_state->curr_connp->trec_pos);
+    DEBUG_VALIDATE_BUG_ON(certs_len != 0 && certs_len < ssl_state->curr_connp->trec_pos);
+
+    /* if we didn't yet get enough to determine the certs len, or we can
+     * see we didn't buffer enough for the certs, don't bother trying to
+     * parse it the data */
+    if (certs_len == 0 || certs_len > ssl_state->curr_connp->trec_pos) {
+        ssl_state->curr_connp->bytes_processed += write_len;
+        SCLogDebug("bytes_processed %u record_len %u",
+                ssl_state->curr_connp->bytes_processed, ssl_state->curr_connp->record_length);
+        ValidateTrecBuffer(ssl_state->curr_connp);
+        return write_len;
+    }
+
+    int rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
+            ssl_state->curr_connp->trec_pos);
+    SCLogDebug("rc %d", rc);
+    if (rc > 0) {
+        DEBUG_VALIDATE_BUG_ON(rc != (int)ssl_state->curr_connp->trec_pos);
+        SSLParserHSReset(ssl_state->curr_connp);
+    } else if (rc < 0) {
+        SCLogDebug("error parsing cert, reset state");
+        SSLParserHSReset(ssl_state->curr_connp);
+        /* fall through to still consume the cert bytes */
+    }
+    ssl_state->curr_connp->bytes_processed += write_len;
+    ValidateTrecBuffer(ssl_state->curr_connp);
+    return write_len;
+}
+
+/**
+ *  \retval parsed number of consumed bytes
+ *  \retval < 0 error
+ */
 static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
                                    uint32_t input_len, uint8_t direction)
 {
-    void *ptmp;
     const uint8_t *initial_input = input;
     uint32_t parsed = 0;
     int rc;
@@ -1490,90 +1656,9 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input,
                            "direction!");
                 break;
             }
-            if (ssl_state->curr_connp->trec == NULL) {
-                ssl_state->curr_connp->trec_len =
-                        2 * ssl_state->curr_connp->record_length +
-                        SSLV3_RECORD_HDR_LEN + 1;
-                ssl_state->curr_connp->trec =
-                        SCMalloc(ssl_state->curr_connp->trec_len);
-            }
-            if (ssl_state->curr_connp->trec_pos + input_len >=
-                    ssl_state->curr_connp->trec_len) {
-                ssl_state->curr_connp->trec_len =
-                        ssl_state->curr_connp->trec_pos + 2 * input_len + 1;
-                ptmp = SCRealloc(ssl_state->curr_connp->trec,
-                        ssl_state->curr_connp->trec_len);
-
-                if (unlikely(ptmp == NULL)) {
-                    SCFree(ssl_state->curr_connp->trec);
-                }
-
-                ssl_state->curr_connp->trec = ptmp;
-            }
-            if (unlikely(ssl_state->curr_connp->trec == NULL)) {
-                ssl_state->curr_connp->trec_len = 0;
-                /* error, skip packet */
-                parsed += input_len;
-                (void)parsed; /* for scan-build */
-                ssl_state->curr_connp->bytes_processed += input_len;
-                return -1;
-            }
-
-            uint32_t write_len = 0;
-            if ((ssl_state->curr_connp->bytes_processed + input_len) >
-                    ssl_state->curr_connp->record_length +
-                    (SSLV3_RECORD_HDR_LEN)) {
-                if ((ssl_state->curr_connp->record_length +
-                        SSLV3_RECORD_HDR_LEN) <
-                        ssl_state->curr_connp->bytes_processed) {
-                    SSLSetEvent(ssl_state,
-                            TLS_DECODER_EVENT_INVALID_SSL_RECORD);
-                    return -1;
-                }
-                write_len = (ssl_state->curr_connp->record_length +
-                        SSLV3_RECORD_HDR_LEN) -
-                        ssl_state->curr_connp->bytes_processed;
-            } else {
-                write_len = input_len;
-            }
-
-            if (SafeMemcpy(ssl_state->curr_connp->trec,
-                        ssl_state->curr_connp->trec_pos,
-                        ssl_state->curr_connp->trec_len,
-                        initial_input, 0, input_len, write_len) != 0) {
-                return -1;
-            }
-            ssl_state->curr_connp->trec_pos += write_len;
-
-            rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec,
-                                        ssl_state->curr_connp->trec_pos);
-
-            if (rc > 0) {
-                /* do not return normally if the packet was fragmented:
-                   we would return the size of the _entire_ message,
-                   while we expect only the number of bytes parsed bytes
-                   from the _current_ fragment */
-                if (write_len < (ssl_state->curr_connp->trec_pos - rc)) {
-                    SSLSetEvent(ssl_state,
-                            TLS_DECODER_EVENT_INVALID_SSL_RECORD);
-                    return -1;
-                }
-
-                uint32_t diff = write_len -
-                        (ssl_state->curr_connp->trec_pos - rc);
-                ssl_state->curr_connp->bytes_processed += diff;
-
-                ssl_state->curr_connp->trec_pos = 0;
-                ssl_state->curr_connp->handshake_type = 0;
-                ssl_state->curr_connp->hs_bytes_processed = 0;
-                ssl_state->curr_connp->message_length = 0;
-
-                return diff;
-            } else {
-                ssl_state->curr_connp->bytes_processed += write_len;
-                parsed += write_len;
-                return parsed;
-            }
+            rc = SSLv3ParseHandshakeTypeCertificate(ssl_state,
+                    initial_input, input_len);
+            return rc;
 
             break;
         case SSLV3_HS_HELLO_REQUEST: