From: Victor Julien Date: Fri, 21 Feb 2020 21:52:24 +0000 (+0100) Subject: ssl: handshake parsing code cleanup X-Git-Tag: suricata-6.0.0-beta1~447 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ab44b5edacfc879917f6bc9c2060476554787396;p=thirdparty%2Fsuricata.git ssl: handshake parsing code cleanup --- diff --git a/src/app-layer-ssl.c b/src/app-layer-ssl.c index f0b474e5ac..1ffdad5e25 100644 --- a/src/app-layer-ssl.c +++ b/src/app-layer-ssl.c @@ -1406,11 +1406,82 @@ error: return -1; } +static inline bool +HaveEntireRecord(const SSLStateConnp *curr_connp, const uint32_t input_len) +{ + return (curr_connp->bytes_processed + input_len) >= + (curr_connp->record_length + SSLV3_RECORD_HDR_LEN); +} + +static inline bool +RecordAlreadyProcessed(const SSLStateConnp *curr_connp) +{ + return ((curr_connp->record_length + SSLV3_RECORD_HDR_LEN) < + curr_connp->bytes_processed); +} + +static inline int SSLv3ParseHandshakeTypeCertificate(SSLState *ssl_state, + const uint8_t * const initial_input, + const uint32_t input_len) +{ + if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) { + /* error, skip packet */ + ssl_state->curr_connp->bytes_processed += input_len; + return -1; + } + + uint32_t write_len = 0; + if (HaveEntireRecord(ssl_state->curr_connp, input_len)) { + if (RecordAlreadyProcessed(ssl_state->curr_connp)) { + SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD); + return -1; + } + write_len = (ssl_state->curr_connp->record_length + + SSLV3_RECORD_HDR_LEN) - ssl_state->curr_connp->bytes_processed; + } else { + write_len = input_len; + } + + if (SafeMemcpy(ssl_state->curr_connp->trec, + ssl_state->curr_connp->trec_pos, + ssl_state->curr_connp->trec_len, + initial_input, 0, input_len, write_len) != 0) { + return -1; + } + ssl_state->curr_connp->trec_pos += write_len; + + int rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec, + ssl_state->curr_connp->trec_pos); + if (rc > 0) { + /* do not return normally if the packet was fragmented: + we would return the size of the _entire_ message, + while we expect only the number of bytes parsed bytes + from the _current_ fragment */ + if (write_len < (ssl_state->curr_connp->trec_pos - rc)) { + SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD); + return -1; + } + + uint32_t diff = write_len - + (ssl_state->curr_connp->trec_pos - rc); + ssl_state->curr_connp->bytes_processed += diff; + + ssl_state->curr_connp->trec_pos = 0; + ssl_state->curr_connp->handshake_type = 0; + ssl_state->curr_connp->hs_bytes_processed = 0; + ssl_state->curr_connp->message_length = 0; + + return diff; + } else { + ssl_state->curr_connp->bytes_processed += write_len; + return write_len; + } +} + static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, uint32_t input_len, uint8_t direction) { const uint8_t *initial_input = input; - uint32_t parsed = 0; int rc; if (input_len == 0) { @@ -1425,7 +1496,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, if (input_len >= ssl_state->curr_connp->message_length && input_len >= 40) { rc = TLSDecodeHandshakeHello(ssl_state, input, input_len); - if (rc < 0) return rc; } @@ -1440,7 +1510,6 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, input_len >= 40) { rc = TLSDecodeHandshakeHello(ssl_state, input, ssl_state->curr_connp->message_length); - if (rc < 0) return rc; } @@ -1462,72 +1531,10 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, "direction!"); break; } - - if (EnsureRecordSpace(ssl_state->curr_connp, input_len) < 0) { - /* error, skip packet */ - parsed += input_len; - (void)parsed; /* for scan-build */ - ssl_state->curr_connp->bytes_processed += input_len; - return -1; - } - - uint32_t write_len = 0; - if ((ssl_state->curr_connp->bytes_processed + input_len) > - ssl_state->curr_connp->record_length + - (SSLV3_RECORD_HDR_LEN)) { - if ((ssl_state->curr_connp->record_length + - SSLV3_RECORD_HDR_LEN) < - ssl_state->curr_connp->bytes_processed) { - SSLSetEvent(ssl_state, - TLS_DECODER_EVENT_INVALID_SSL_RECORD); - return -1; - } - write_len = (ssl_state->curr_connp->record_length + - SSLV3_RECORD_HDR_LEN) - - ssl_state->curr_connp->bytes_processed; - } else { - write_len = input_len; - } - - if (SafeMemcpy(ssl_state->curr_connp->trec, - ssl_state->curr_connp->trec_pos, - ssl_state->curr_connp->trec_len, - initial_input, 0, input_len, write_len) != 0) { - return -1; - } - ssl_state->curr_connp->trec_pos += write_len; - - rc = TlsDecodeHSCertificate(ssl_state, ssl_state->curr_connp->trec, - ssl_state->curr_connp->trec_pos); - - if (rc > 0) { - /* do not return normally if the packet was fragmented: - we would return the size of the _entire_ message, - while we expect only the number of bytes parsed bytes - from the _current_ fragment */ - if (write_len < (ssl_state->curr_connp->trec_pos - rc)) { - SSLSetEvent(ssl_state, - TLS_DECODER_EVENT_INVALID_SSL_RECORD); - return -1; - } - - uint32_t diff = write_len - - (ssl_state->curr_connp->trec_pos - rc); - ssl_state->curr_connp->bytes_processed += diff; - - ssl_state->curr_connp->trec_pos = 0; - ssl_state->curr_connp->handshake_type = 0; - ssl_state->curr_connp->hs_bytes_processed = 0; - ssl_state->curr_connp->message_length = 0; - - return diff; - } else { - ssl_state->curr_connp->bytes_processed += write_len; - parsed += write_len; - return parsed; - } - + return SSLv3ParseHandshakeTypeCertificate(ssl_state, + initial_input, input_len); break; + case SSLV3_HS_HELLO_REQUEST: case SSLV3_HS_CERTIFICATE_REQUEST: case SSLV3_HS_CERTIFICATE_VERIFY: @@ -1546,10 +1553,8 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, ssl_state->flags |= ssl_state->current_flags; uint32_t write_len = 0; - if ((ssl_state->curr_connp->bytes_processed + input_len) >= - ssl_state->curr_connp->record_length + (SSLV3_RECORD_HDR_LEN)) { - if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) < - ssl_state->curr_connp->bytes_processed) { + if (HaveEntireRecord(ssl_state->curr_connp, input_len)) { + if (RecordAlreadyProcessed(ssl_state->curr_connp)) { SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD); return -1; } @@ -1566,24 +1571,18 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, const uint8_t *input, SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_SSL_RECORD); return -1; } - parsed += ssl_state->curr_connp->message_length - + const uint32_t parsed = ssl_state->curr_connp->message_length - ssl_state->curr_connp->trec_pos; - - ssl_state->curr_connp->bytes_processed += - ssl_state->curr_connp->message_length - - ssl_state->curr_connp->trec_pos; - + ssl_state->curr_connp->bytes_processed += parsed; ssl_state->curr_connp->handshake_type = 0; ssl_state->curr_connp->hs_bytes_processed = 0; ssl_state->curr_connp->message_length = 0; ssl_state->curr_connp->trec_pos = 0; - return parsed; } else { ssl_state->curr_connp->trec_pos += write_len; ssl_state->curr_connp->bytes_processed += write_len; - parsed += write_len; - return parsed; + return write_len; } } @@ -2221,18 +2220,16 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, AppLayerParserState *pstate, const uint8_t *input, uint32_t input_len) { - int retval = 0; uint32_t parsed = 0; if (ssl_state->curr_connp->bytes_processed < SSLV3_RECORD_HDR_LEN) { - retval = SSLv3ParseRecord(direction, ssl_state, input, input_len); + int retval = SSLv3ParseRecord(direction, ssl_state, input, input_len); if (retval < 0) { SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_TLS_HEADER); return -1; - } else { - parsed += retval; - input_len -= retval; } + parsed += retval; + input_len -= retval; } if (input_len == 0) { @@ -2296,7 +2293,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, break; - case SSLV3_HANDSHAKE_PROTOCOL: + case SSLV3_HANDSHAKE_PROTOCOL: { if (ssl_state->flags & SSL_AL_FLAG_CHANGE_CIPHER_SPEC) { /* In TLSv1.3, ChangeCipherSpec is only used for middlebox compability (rfc8446, appendix D.4). */ @@ -2314,8 +2311,8 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, return -1; } - retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed, - input_len, direction); + int retval = SSLv3ParseHandshakeProtocol(ssl_state, input + parsed, + input_len, direction); if (retval < 0) { SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_HANDSHAKE_MESSAGE); @@ -2349,15 +2346,15 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, } break; - - case SSLV3_HEARTBEAT_PROTOCOL: - retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed, + } + case SSLV3_HEARTBEAT_PROTOCOL: { + int retval = SSLv3ParseHeartbeatProtocol(ssl_state, input + parsed, input_len, direction); if (retval < 0) return -1; break; - + } default: /* \todo fix the event from invalid rule to unknown rule */ SSLSetEvent(ssl_state, TLS_DECODER_EVENT_INVALID_RECORD_TYPE); @@ -2365,8 +2362,7 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, return -1; } - if (input_len + ssl_state->curr_connp->bytes_processed >= - ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) { + if (HaveEntireRecord(ssl_state->curr_connp, input_len)) { if ((ssl_state->curr_connp->record_length + SSLV3_RECORD_HDR_LEN) < ssl_state->curr_connp->bytes_processed) { /* defensive checks. Something is wrong. */ @@ -2392,7 +2388,6 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, ssl_state->curr_connp->bytes_processed += input_len; return parsed; } - } /**