From: Frederik Wedel-Heinen Date: Tue, 11 Jun 2024 08:51:38 +0000 (+0200) Subject: Refactor handshake msg header parsing etc. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cb86ed419a6663da8b462aaab9c4d72bf343924b;p=thirdparty%2Fopenssl.git Refactor handshake msg header parsing etc. Reviewed-by: Tomas Mraz Reviewed-by: Matt Caswell (Merged from https://github.com/openssl/openssl/pull/24607) --- diff --git a/include/internal/common.h b/include/internal/common.h index 8abe61bd6ec..c5416020ed2 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -180,10 +180,6 @@ __owur static ossl_inline int ossl_assert_int(int expr, const char *exprstr, (((unsigned long)((c)[1]))<< 8)| \ (((unsigned long)((c)[2])) )),(c)+=3) -# define l2n3(l,c) (((c)[0]=(unsigned char)(((l)>>16)&0xff), \ - (c)[1]=(unsigned char)(((l)>> 8)&0xff), \ - (c)[2]=(unsigned char)(((l) )&0xff)),(c)+=3) - #define l3n2(c,l) (l =((uint64_t)(*((c)++)))<<16, \ l|=((uint64_t)(*((c)++)))<< 8, \ l|=((uint64_t)(*((c)++)))) diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c index 05117b2c60f..2db9fb11050 100644 --- a/ssl/d1_lib.c +++ b/ssl/d1_lib.c @@ -95,7 +95,7 @@ int dtls1_new(SSL *ssl) return 0; } - d1->buffered_messages = pqueue_new(); + d1->rcvd_messages = pqueue_new(); d1->sent_messages = pqueue_new(); if (s->server) { @@ -106,8 +106,8 @@ int dtls1_new(SSL *ssl) d1->mtu = 0; d1->hello_verify_request = SSL_HVR_NONE; - if (d1->buffered_messages == NULL || d1->sent_messages == NULL) { - pqueue_free(d1->buffered_messages); + if (d1->rcvd_messages == NULL || d1->sent_messages == NULL) { + pqueue_free(d1->rcvd_messages); pqueue_free(d1->sent_messages); OPENSSL_free(d1); ssl3_free(ssl); @@ -133,7 +133,7 @@ void dtls1_clear_received_buffer(SSL_CONNECTION *s) pitem *item = NULL; hm_fragment *frag = NULL; - while ((item = pqueue_pop(s->d1->buffered_messages)) != NULL) { + while ((item = pqueue_pop(s->d1->rcvd_messages)) != NULL) { frag = (hm_fragment *)item->data; dtls1_hm_fragment_free(frag); pitem_free(item); @@ -143,22 +143,21 @@ void dtls1_clear_received_buffer(SSL_CONNECTION *s) void dtls1_clear_sent_buffer(SSL_CONNECTION *s) { pitem *item = NULL; - hm_fragment *frag = NULL; while ((item = pqueue_pop(s->d1->sent_messages)) != NULL) { - frag = (hm_fragment *)item->data; + dtls_sent_msg *sent_msg = (dtls_sent_msg *)item->data; - if (frag->msg_header.is_ccs - && frag->msg_header.saved_retransmit_state.wrlmethod != NULL - && s->rlayer.wrl != frag->msg_header.saved_retransmit_state.wrl) { + if (sent_msg->record_type == SSL3_RT_CHANGE_CIPHER_SPEC + && sent_msg->saved_retransmit_state.wrlmethod != NULL + && s->rlayer.wrl != sent_msg->saved_retransmit_state.wrl) { /* * If we're freeing the CCS then we're done with the old wrl and it * can bee freed */ - frag->msg_header.saved_retransmit_state.wrlmethod->free(frag->msg_header.saved_retransmit_state.wrl); + sent_msg->saved_retransmit_state.wrlmethod->free(sent_msg->saved_retransmit_state.wrl); } - dtls1_hm_fragment_free(frag); + dtls1_sent_msg_free(sent_msg); pitem_free(item); } } @@ -173,7 +172,7 @@ void dtls1_free(SSL *ssl) if (s->d1 != NULL) { dtls1_clear_queues(s); - pqueue_free(s->d1->buffered_messages); + pqueue_free(s->d1->rcvd_messages); pqueue_free(s->d1->sent_messages); } @@ -187,7 +186,7 @@ void dtls1_free(SSL *ssl) int dtls1_clear(SSL *ssl) { - pqueue *buffered_messages; + pqueue *rcvd_messages; pqueue *sent_messages; size_t mtu; size_t link_mtu; @@ -202,7 +201,7 @@ int dtls1_clear(SSL *ssl) if (s->d1) { DTLS_timer_cb timer_cb = s->d1->timer_cb; - buffered_messages = s->d1->buffered_messages; + rcvd_messages = s->d1->rcvd_messages; sent_messages = s->d1->sent_messages; mtu = s->d1->mtu; link_mtu = s->d1->link_mtu; @@ -223,7 +222,7 @@ int dtls1_clear(SSL *ssl) s->d1->link_mtu = link_mtu; } - s->d1->buffered_messages = buffered_messages; + s->d1->rcvd_messages = rcvd_messages; s->d1->sent_messages = sent_messages; } @@ -424,7 +423,7 @@ int dtls1_handle_timeout(SSL_CONNECTION *s) dtls1_start_timer(s); /* Calls SSLfatal() if required */ - return dtls1_retransmit_buffered_messages(s); + return dtls1_retransmit_sent_messages(s); } #define LISTEN_SUCCESS 2 diff --git a/ssl/record/rec_layer_d1.c b/ssl/record/rec_layer_d1.c index 22f537f005d..9cf4ebe1e5d 100644 --- a/ssl/record/rec_layer_d1.c +++ b/ssl/record/rec_layer_d1.c @@ -490,7 +490,7 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type, * Unexpected handshake message (Client Hello, or protocol violation) */ if (rr->type == SSL3_RT_HANDSHAKE && !ossl_statem_get_in_handshake(sc)) { - struct hm_header_st msg_hdr; + unsigned char msg_type; /* * This may just be a stale retransmit. Also sanity check that we have @@ -503,19 +503,19 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type, goto start; } - dtls1_get_message_header(rr->data, &msg_hdr); + msg_type = *rr->data; /* * If we are server, we may have a repeated FINISHED of the client * here, then retransmit our CCS and FINISHED. */ - if (msg_hdr.type == SSL3_MT_FINISHED) { + if (msg_type == SSL3_MT_FINISHED) { if (dtls1_check_timeout_num(sc) < 0) { /* SSLfatal) already called */ return -1; } - if (dtls1_retransmit_buffered_messages(sc) <= 0) { + if (dtls1_retransmit_sent_messages(sc) <= 0) { /* Fail if we encountered a fatal error */ if (ossl_statem_in_error(sc)) return -1; diff --git a/ssl/ssl_local.h b/ssl/ssl_local.h index 8829327cb0f..dbe2edf4bd1 100644 --- a/ssl/ssl_local.h +++ b/ssl/ssl_local.h @@ -1940,8 +1940,6 @@ struct hm_header_st { unsigned short seq; size_t frag_off; size_t frag_len; - unsigned int is_ccs; - struct dtls1_retransmit_state saved_retransmit_state; }; typedef struct hm_fragment_st { @@ -1973,6 +1971,19 @@ pitem *pqueue_iterator(pqueue *pq); pitem *pqueue_next(piterator *iter); size_t pqueue_size(pqueue *pq); +typedef struct dtls_msg_info_st { + unsigned char msg_type; + size_t msg_body_len; + unsigned short msg_seq; +} dtls_msg_info; + +typedef struct dtls_sent_msg_st { + dtls_msg_info msg_info; + int record_type; + unsigned char *msg_buf; + struct dtls1_retransmit_state saved_retransmit_state; +} dtls_sent_msg; + typedef struct dtls1_state_st { unsigned char cookie[DTLS1_COOKIE_LENGTH]; size_t cookie_len; @@ -1981,16 +1992,16 @@ typedef struct dtls1_state_st { unsigned short handshake_write_seq; unsigned short next_handshake_write_seq; unsigned short handshake_read_seq; - /* Buffered handshake messages */ - pqueue *buffered_messages; + /* Buffered received handshake messages */ + pqueue *rcvd_messages; /* Buffered (sent) handshake records */ pqueue *sent_messages; /* Flag to indicate current HelloVerifyRequest status */ enum {SSL_HVR_NONE = 0, SSL_HVR_RECEIVED} hello_verify_request; size_t link_mtu; /* max on-the-wire DTLS packet size */ size_t mtu; /* max DTLS packet size */ - struct hm_header_st w_msg_hdr; - struct hm_header_st r_msg_hdr; + dtls_msg_info w_msg; + unsigned short r_msg_seq; /* Number of alerts received so far */ unsigned int timeout_num_alerts; /* @@ -2731,25 +2742,19 @@ __owur int ssl_get_min_max_version(const SSL_CONNECTION *s, int *min_version, int *max_version, int *real_max); __owur OSSL_TIME tls1_default_timeout(void); -__owur int dtls1_do_write(SSL_CONNECTION *s, uint8_t type); -void dtls1_set_message_header(SSL_CONNECTION *s, - unsigned char mt, - size_t len, - size_t frag_off, size_t frag_len); +__owur int dtls1_do_write(SSL_CONNECTION *s, uint8_t recordtype); int dtls1_write_app_data_bytes(SSL *s, uint8_t type, const void *buf_, size_t len, size_t *written); __owur int dtls1_read_failed(SSL_CONNECTION *s, int code); -__owur int dtls1_buffer_message(SSL_CONNECTION *s, int ccs); +__owur int dtls1_buffer_sent_message(SSL_CONNECTION *s, int record_type); __owur int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found); __owur int dtls1_get_queue_priority(unsigned short seq, int is_ccs); -int dtls1_retransmit_buffered_messages(SSL_CONNECTION *s); +int dtls1_retransmit_sent_messages(SSL_CONNECTION *s); void dtls1_clear_received_buffer(SSL_CONNECTION *s); void dtls1_clear_sent_buffer(SSL_CONNECTION *s); -void dtls1_get_message_header(const unsigned char *data, - struct hm_header_st *msg_hdr); __owur OSSL_TIME dtls1_default_timeout(void); __owur int dtls1_get_timeout(const SSL_CONNECTION *s, OSSL_TIME *timeleft); __owur int dtls1_check_timeout_num(SSL_CONNECTION *s); @@ -2761,6 +2766,7 @@ __owur int dtls_raw_hello_verify_request(WPACKET *pkt, unsigned char *cookie, size_t cookie_len); __owur size_t dtls1_min_mtu(SSL_CONNECTION *s); void dtls1_hm_fragment_free(hm_fragment *frag); +void dtls1_sent_msg_free(dtls_sent_msg *msg); __owur int dtls1_query_mtu(SSL_CONNECTION *s); __owur int tls1_new(SSL *s); diff --git a/ssl/statem/statem_dtls.c b/ssl/statem/statem_dtls.c index a9ee2cfdb40..5b7b0c4b456 100644 --- a/ssl/statem/statem_dtls.c +++ b/ssl/statem/statem_dtls.c @@ -46,55 +46,59 @@ static const unsigned char bitmask_end_values[] = { 0xff, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f }; -static void dtls1_set_message_header_int(SSL_CONNECTION *s, unsigned char mt, - size_t len, - unsigned short seq_num, - size_t frag_off, - size_t frag_len); static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, size_t *len); +static dtls_sent_msg *dtls1_sent_msg_new(size_t msg_len) +{ + dtls_sent_msg *msg = OPENSSL_malloc(sizeof(*msg) + msg_len); + + if (msg == NULL) + return NULL; + + memset(msg, 0, sizeof(*msg)); + + /* zero length msg gets msg->msg_buf == NULL */ + if (msg_len > 0) + msg->msg_buf = (unsigned char *)(msg + 1); + + return msg; +} + +void dtls1_sent_msg_free(dtls_sent_msg *msg) +{ + OPENSSL_free(msg); +} + static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) { - hm_fragment *frag = NULL; - unsigned char *buf = NULL; - unsigned char *bitmask = NULL; + const size_t bitmask_len = (reassembly ? RSMBLY_BITMASK_SIZE(frag_len) : 0); + hm_fragment *frag = OPENSSL_malloc(sizeof(*frag) + frag_len + bitmask_len); - if ((frag = OPENSSL_zalloc(sizeof(*frag))) == NULL) + if (frag == NULL) return NULL; - if (frag_len) { - if ((buf = OPENSSL_malloc(frag_len)) == NULL) { - OPENSSL_free(frag); - return NULL; - } - } + memset(frag, 0, sizeof(*frag)); /* zero length fragment gets zero frag->fragment */ - frag->fragment = buf; + if (frag_len > 0) + frag->fragment = (unsigned char *)(frag + 1); /* Initialize reassembly bitmask if necessary */ - if (reassembly) { - bitmask = OPENSSL_zalloc(RSMBLY_BITMASK_SIZE(frag_len)); - if (bitmask == NULL) { - OPENSSL_free(buf); - OPENSSL_free(frag); - return NULL; - } - } + if (bitmask_len > 0) { + if (frag->fragment == NULL) + frag->reassembly = (unsigned char *)(frag + 1); + else + frag->reassembly = frag->fragment + frag_len; - frag->reassembly = bitmask; + memset(frag->reassembly, 0, bitmask_len); + } return frag; } void dtls1_hm_fragment_free(hm_fragment *frag) { - if (!frag) - return; - - OPENSSL_free(frag->fragment); - OPENSSL_free(frag->reassembly); OPENSSL_free(frag); } @@ -136,7 +140,7 @@ static int dtls1_write_hm_header(unsigned char *msgheaderstart, * |-- header3 --||-- fragment3 --| * ......... */ -int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) +int dtls1_do_write(SSL_CONNECTION *s, uint8_t recordtype) { int ret; size_t written; @@ -146,13 +150,13 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) SSL *ssl = SSL_CONNECTION_GET_SSL(s); SSL *ussl = SSL_CONNECTION_GET_USER_SSL(s); unsigned char *data = (unsigned char *)s->init_buf->data; - unsigned short msg_seq = s->d1->w_msg_hdr.seq; + unsigned short msg_seq = s->d1->w_msg.msg_seq; unsigned char msg_type = 0; - if (type == SSL3_RT_HANDSHAKE) { + if (recordtype == SSL3_RT_HANDSHAKE) { msg_type = *data++; l3n2(data, msg_len); - } else if (ossl_assert(type == SSL3_RT_CHANGE_CIPHER_SPEC)) { + } else if (ossl_assert(recordtype == SSL3_RT_CHANGE_CIPHER_SPEC)) { msg_type = SSL3_MT_CCS; msg_len = 0; /* SSL3_RT_CHANGE_CIPHER_SPEC */ } else { @@ -167,7 +171,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) /* should have something reasonable now */ return -1; - if (s->init_off == 0 && type == SSL3_RT_HANDSHAKE) { + if (s->init_off == 0 && recordtype == SSL3_RT_HANDSHAKE) { if (!ossl_assert(s->init_num == msg_len + DTLS1_HM_HEADER_LENGTH)) return -1; } @@ -180,7 +184,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) while (s->init_num > 0) { unsigned char *msgstart; - if (type == SSL3_RT_HANDSHAKE && s->init_off > 0) { + if (recordtype == SSL3_RT_HANDSHAKE && s->init_off > 0) { /* * We must be writing a fragment other than the first one * and this is the first attempt at writing out this fragment @@ -237,7 +241,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) msgstart = (unsigned char *)&s->init_buf->data[s->init_off]; - if (type == SSL3_RT_HANDSHAKE) { + if (recordtype == SSL3_RT_HANDSHAKE) { const size_t fragoff = s->init_off; const size_t fraglen = len - DTLS1_HM_HEADER_LENGTH; @@ -251,7 +255,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) return -1; } - ret = dtls1_write_bytes(s, type, msgstart, len, &written); + ret = dtls1_write_bytes(s, recordtype, msgstart, len, &written); if (ret <= 0) { /* @@ -287,7 +291,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) assert(s->s3.tmp.new_compression != NULL || BIO_wpending(s->wbio) <= (int)s->d1->mtu); - if (type == SSL3_RT_HANDSHAKE && !s->d1->retransmitting) { + if (recordtype == SSL3_RT_HANDSHAKE && !s->d1->retransmitting) { /* * should not be done for 'Hello Request's, but in that case * we'll ignore the result anyway @@ -325,7 +329,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) if (written == s->init_num) { if (s->msg_callback) - s->msg_callback(1, s->version, type, s->init_buf->data, + s->msg_callback(1, s->version, recordtype, s->init_buf->data, s->init_off + s->init_num, ussl, s->msg_callback_arg); @@ -344,14 +348,11 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) int dtls_get_message(SSL_CONNECTION *s, int *mt) { - struct hm_header_st *msg_hdr; - unsigned char *p; - size_t msg_len; + unsigned char *rec_data; size_t tmplen; int errtype; - msg_hdr = &s->d1->r_msg_hdr; - memset(msg_hdr, 0, sizeof(*msg_hdr)); + s->d1->r_msg_seq = 0; again: if (!dtls_get_reassembled_message(s, &errtype, &tmplen)) { @@ -365,12 +366,12 @@ int dtls_get_message(SSL_CONNECTION *s, int *mt) *mt = s->s3.tmp.message_type; - p = (unsigned char *)s->init_buf->data; + rec_data = (unsigned char *)s->init_buf->data; if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) { if (s->msg_callback) { s->msg_callback(0, s->version, SSL3_RT_CHANGE_CIPHER_SPEC, - p, 1, SSL_CONNECTION_GET_USER_SSL(s), + rec_data, 1, SSL_CONNECTION_GET_USER_SSL(s), s->msg_callback_arg); } /* @@ -379,16 +380,11 @@ int dtls_get_message(SSL_CONNECTION *s, int *mt) return 1; } - msg_len = msg_hdr->msg_len; - /* reconstruct message header */ - *(p++) = msg_hdr->type; - l2n3(msg_len, p); - s2n(msg_hdr->seq, p); - l2n3(0, p); - l2n3(msg_len, p); + dtls1_write_hm_header(rec_data, s->s3.tmp.message_type, s->s3.tmp.message_size, + s->d1->r_msg_seq, 0, s->s3.tmp.message_size); - memset(msg_hdr, 0, sizeof(*msg_hdr)); + s->d1->r_msg_seq = 0; s->d1->handshake_read_seq++; @@ -452,7 +448,7 @@ static size_t dtls1_max_handshake_message_len(const SSL_CONNECTION *s) } static int dtls1_preprocess_fragment(SSL_CONNECTION *s, - struct hm_header_st *msg_hdr) + const struct hm_header_st * const msg_hdr) { size_t frag_off, frag_len, msg_len; @@ -467,30 +463,19 @@ static int dtls1_preprocess_fragment(SSL_CONNECTION *s, return 0; } - if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */ - /* - * msg_len is limited to 2^24, but is effectively checked against - * dtls_max_handshake_message_len(s) above - */ - if (!BUF_MEM_grow_clean(s->init_buf, msg_len + DTLS1_HM_HEADER_LENGTH)) { - SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB); - return 0; - } - - s->s3.tmp.message_size = msg_len; - s->d1->r_msg_hdr.msg_len = msg_len; - s->s3.tmp.message_type = msg_hdr->type; - s->d1->r_msg_hdr.type = msg_hdr->type; - s->d1->r_msg_hdr.seq = msg_hdr->seq; - } else if (msg_len != s->d1->r_msg_hdr.msg_len) { - /* - * They must be playing with us! BTW, failure to enforce upper limit - * would open possibility for buffer overrun. - */ - SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_EXCESSIVE_MESSAGE_SIZE); + /* + * msg_len is limited to 2^24, but is effectively checked against + * dtls_max_handshake_message_len(s) above + */ + if (!BUF_MEM_grow_clean(s->init_buf, msg_len + DTLS1_HM_HEADER_LENGTH)) { + SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB); return 0; } + s->s3.tmp.message_size = msg_len; + s->s3.tmp.message_type = msg_hdr->type; + s->d1->r_msg_seq = msg_hdr->seq; + return 1; } @@ -512,7 +497,7 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len) int ret; int chretran = 0; - iter = pqueue_iterator(s->d1->buffered_messages); + iter = pqueue_iterator(s->d1->rcvd_messages); do { item = pqueue_next(&iter); if (item == NULL) @@ -533,7 +518,7 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len) * It is safe to pop this message from the queue even though * we have an active iterator */ - pqueue_pop(s->d1->buffered_messages); + pqueue_pop(s->d1->rcvd_messages); dtls1_hm_fragment_free(frag); pitem_free(item); item = NULL; @@ -553,7 +538,7 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len) * We have fragments for both a ClientHello without * cookie and one with. Ditch the one without. */ - pqueue_pop(s->d1->buffered_messages); + pqueue_pop(s->d1->rcvd_messages); dtls1_hm_fragment_free(frag); pitem_free(item); item = next; @@ -574,7 +559,7 @@ static int dtls1_retrieve_buffered_fragment(SSL_CONNECTION *s, size_t *len) if (s->d1->handshake_read_seq == frag->msg_header.seq || chretran) { size_t frag_len = frag->msg_header.frag_len; - pqueue_pop(s->d1->buffered_messages); + pqueue_pop(s->d1->rcvd_messages); /* Calls SSLfatal() as required */ ret = dtls1_preprocess_fragment(s, &frag->msg_header); @@ -635,7 +620,7 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s, memset(seq64be, 0, sizeof(seq64be)); seq64be[6] = (unsigned char)(msg_hdr->seq >> 8); seq64be[7] = (unsigned char)msg_hdr->seq; - item = pqueue_find(s->d1->buffered_messages, seq64be); + item = pqueue_find(s->d1->rcvd_messages, seq64be); if (item == NULL) { frag = dtls1_hm_fragment_new(msg_hdr->msg_len, 1); @@ -679,8 +664,6 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s, frag->fragment + msg_hdr->frag_off, frag_len, 0, &readbytes); if (i <= 0 || readbytes != frag_len) - i = -1; - if (i <= 0) goto err; RSMBLY_BITMASK_MARK(frag->reassembly, (long)msg_hdr->frag_off, @@ -691,19 +674,15 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s, RSMBLY_BITMASK_IS_COMPLETE(frag->reassembly, (long)msg_hdr->msg_len, is_complete); - if (is_complete) { - OPENSSL_free(frag->reassembly); + if (is_complete) frag->reassembly = NULL; - } if (item == NULL) { item = pitem_new(seq64be, frag); - if (item == NULL) { - i = -1; + if (item == NULL) goto err; - } - item = pqueue_insert(s->d1->buffered_messages, item); + item = pqueue_insert(s->d1->rcvd_messages, item); /* * pqueue_insert fails iff a duplicate item is inserted. However, * |item| cannot be a duplicate. If it were, |pqueue_find|, above, @@ -740,7 +719,7 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s, memset(seq64be, 0, sizeof(seq64be)); seq64be[6] = (unsigned char)(msg_hdr->seq >> 8); seq64be[7] = (unsigned char)msg_hdr->seq; - item = pqueue_find(s->d1->buffered_messages, seq64be); + item = pqueue_find(s->d1->rcvd_messages, seq64be); /* * If we already have an entry and this one is a fragment, don't discard @@ -790,9 +769,7 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s, i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, NULL, frag->fragment, frag_len, 0, &readbytes); - if (i<=0 || readbytes != frag_len) - i = -1; - if (i <= 0) + if (i <= 0 || readbytes != frag_len) goto err; } @@ -800,7 +777,7 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s, if (item == NULL) goto err; - item = pqueue_insert(s->d1->buffered_messages, item); + item = pqueue_insert(s->d1->rcvd_messages, item); /* * pqueue_insert fails iff a duplicate item is inserted. However, * |item| cannot be a duplicate. If it were, |pqueue_find|, above, @@ -821,10 +798,36 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s, return 0; } +static int dtls1_read_hm_header(unsigned char *msgheaderstart, + struct hm_header_st *msg_hdr) +{ + unsigned long msg_len, frag_off, frag_len; + unsigned int msg_seq, msg_type; + PACKET msgheader; + + if (!PACKET_buf_init(&msgheader, msgheaderstart, DTLS1_HM_HEADER_LENGTH) + || !PACKET_get_1(&msgheader, &msg_type) + || !PACKET_get_net_3(&msgheader, &msg_len) + || !PACKET_get_net_2(&msgheader, &msg_seq) + || !PACKET_get_net_3(&msgheader, &frag_off) + || !PACKET_get_net_3(&msgheader, &frag_len) + || PACKET_remaining(&msgheader) != 0) { + return 0; + } + + /* We just checked that values did not exceed max size so cast must be alright */ + msg_hdr->type = (unsigned char)msg_type; + msg_hdr->msg_len = (size_t)msg_len; + msg_hdr->seq = (unsigned short)msg_seq; + msg_hdr->frag_off = (size_t)frag_off; + msg_hdr->frag_len = (size_t)frag_len; + + return 1; +} + static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, size_t *len) { - size_t mlen, frag_off, frag_len; int i, ret; uint8_t recvd_type; struct hm_header_st msg_hdr; @@ -840,14 +843,14 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, redo: /* see if we have the required fragment already */ - ret = dtls1_retrieve_buffered_fragment(s, &frag_len); + ret = dtls1_retrieve_buffered_fragment(s, &msg_hdr.frag_len); if (ret < 0) { /* SSLfatal() already called */ return 0; } if (ret > 0) { - s->init_num = frag_len; - *len = frag_len; + s->init_num = msg_hdr.frag_len; + *len = msg_hdr.frag_len; return 1; } @@ -881,17 +884,16 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, } /* parse the message fragment header */ - dtls1_get_message_header(p, &msg_hdr); - - mlen = msg_hdr.msg_len; - frag_off = msg_hdr.frag_off; - frag_len = msg_hdr.frag_len; + if (!dtls1_read_hm_header(p, &msg_hdr)) { + SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH); + goto f_err; + } /* * We must have at least frag_len bytes left in the record to be read. * Fragments must not span records. */ - if (frag_len > s->rlayer.tlsrecs[s->rlayer.curr_rec].length) { + if (msg_hdr.frag_len > s->rlayer.tlsrecs[s->rlayer.curr_rec].length) { SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH); goto f_err; } @@ -906,7 +908,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, if (!s->server || msg_hdr.seq != 0 || s->d1->handshake_read_seq != 1 - || p[0] != SSL3_MT_CLIENT_HELLO + || msg_hdr.type != SSL3_MT_CLIENT_HELLO || s->statem.hand_state != DTLS_ST_SW_HELLO_VERIFY_REQUEST) { *errtype = dtls1_process_out_of_seq_message(s, &msg_hdr); return 0; @@ -919,21 +921,20 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, chretran = 1; } - if (frag_len && frag_len < mlen) { + if (msg_hdr.frag_len > 0 && msg_hdr.frag_len < msg_hdr.msg_len) { *errtype = dtls1_reassemble_fragment(s, &msg_hdr); return 0; } if (!s->server - && s->d1->r_msg_hdr.frag_off == 0 && s->statem.hand_state != TLS_ST_OK - && p[0] == SSL3_MT_HELLO_REQUEST) { + && msg_hdr.type == SSL3_MT_HELLO_REQUEST) { /* * The server may always send 'Hello Request' messages -- we are * doing a handshake anyway now, so ignore them if their format is * correct. Does not count for 'Finished' MAC. */ - if (p[1] == 0 && p[2] == 0 && p[3] == 0) { + if (msg_hdr.msg_len == 0) { if (s->msg_callback) s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE, p, DTLS1_HM_HEADER_LENGTH, ussl, @@ -941,8 +942,8 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, s->init_num = 0; goto redo; - } else { /* Incorrectly formatted Hello request */ - + } else { + /* Incorrectly formatted Hello request */ SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE); goto f_err; } @@ -953,11 +954,11 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, goto f_err; } - if (frag_len > 0) { - p += DTLS1_HM_HEADER_LENGTH; + if (msg_hdr.frag_len > 0) { + p += DTLS1_HM_HEADER_LENGTH + msg_hdr.frag_off; i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, NULL, - &p[frag_off], frag_len, 0, &readbytes); + p, msg_hdr.frag_len, 0, &readbytes); /* * This shouldn't ever fail due to NBIO because we already checked @@ -976,7 +977,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, * XDTLS: an incorrectly formatted fragment should cause the handshake * to fail */ - if (readbytes != frag_len) { + if (readbytes != msg_hdr.frag_len) { SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH); goto f_err; } @@ -998,7 +999,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, * soon as they sum up to handshake packet length, we assume we have got * all the fragments. */ - *len = s->init_num = frag_len; + *len = s->init_num = msg_hdr.frag_len; return 1; f_err: @@ -1094,7 +1095,7 @@ int dtls1_read_failed(SSL_CONNECTION *s, int code) return dtls1_handle_timeout(s); } -int dtls1_get_queue_priority(unsigned short seq, int is_ccs) +int dtls1_get_queue_priority(unsigned short seq, int record_type) { /* * The index of the retransmission queue actually is the message sequence @@ -1106,36 +1107,37 @@ int dtls1_get_queue_priority(unsigned short seq, int is_ccs) * Finished, it also maintains the order of the index (important for * priority queues) and fits in the unsigned short variable. */ - return seq * 2 - is_ccs; + int lsb = (record_type == SSL3_RT_CHANGE_CIPHER_SPEC); + + return seq * 2 - lsb; } -int dtls1_retransmit_buffered_messages(SSL_CONNECTION *s) +int dtls1_retransmit_sent_messages(SSL_CONNECTION *s) { - pqueue *sent = s->d1->sent_messages; - piterator iter; + piterator iter = pqueue_iterator(s->d1->sent_messages); pitem *item; - hm_fragment *frag; int found = 0; - iter = pqueue_iterator(sent); - for (item = pqueue_next(&iter); item != NULL; item = pqueue_next(&iter)) { - frag = (hm_fragment *)item->data; - if (dtls1_retransmit_message(s, (unsigned short) - dtls1_get_queue_priority - (frag->msg_header.seq, - frag->msg_header.is_ccs), &found) <= 0) + int prio; + dtls_sent_msg *sent_msg = (dtls_sent_msg *)item->data; + + prio = dtls1_get_queue_priority(sent_msg->msg_info.msg_seq, sent_msg->record_type); + + if (dtls1_retransmit_message(s, (unsigned short)prio, &found) <= 0) return -1; } return 1; } -int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs) +int dtls1_buffer_sent_message(SSL_CONNECTION *s, int record_type) { pitem *item; - hm_fragment *frag; + dtls_sent_msg *sent_msg; unsigned char seq64be[8]; + size_t headerlen; + int prio; /* * this function is called immediately after a message has been @@ -1144,54 +1146,40 @@ int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs) if (!ossl_assert(s->init_off == 0)) return 0; - frag = dtls1_hm_fragment_new(s->init_num, 0); - if (frag == NULL) + sent_msg = dtls1_sent_msg_new(s->init_num); + if (sent_msg == NULL) return 0; - memcpy(frag->fragment, s->init_buf->data, s->init_num); + memcpy(sent_msg->msg_buf, s->init_buf->data, s->init_num); - if (is_ccs) { + if (record_type == SSL3_RT_CHANGE_CIPHER_SPEC) /* For DTLS1_BAD_VER the header length is non-standard */ - if (!ossl_assert(s->d1->w_msg_hdr.msg_len + - ((s->version == - DTLS1_BAD_VER) ? 3 : DTLS1_CCS_HEADER_LENGTH) - == (unsigned int)s->init_num)) { - dtls1_hm_fragment_free(frag); - return 0; - } - } else { - if (!ossl_assert(s->d1->w_msg_hdr.msg_len + - DTLS1_HM_HEADER_LENGTH == (unsigned int)s->init_num)) { - dtls1_hm_fragment_free(frag); - return 0; - } + headerlen = (s->version == DTLS1_BAD_VER) ? 3 : DTLS1_CCS_HEADER_LENGTH; + else + headerlen = DTLS1_HM_HEADER_LENGTH; + + if (!ossl_assert(s->d1->w_msg.msg_body_len + headerlen == s->init_num)) { + dtls1_sent_msg_free(sent_msg); + return 0; } - frag->msg_header.msg_len = s->d1->w_msg_hdr.msg_len; - frag->msg_header.seq = s->d1->w_msg_hdr.seq; - frag->msg_header.type = s->d1->w_msg_hdr.type; - frag->msg_header.frag_off = 0; - frag->msg_header.frag_len = s->d1->w_msg_hdr.msg_len; - frag->msg_header.is_ccs = is_ccs; + sent_msg->msg_info.msg_body_len = s->d1->w_msg.msg_body_len; + sent_msg->msg_info.msg_seq = s->d1->w_msg.msg_seq; + sent_msg->msg_info.msg_type = s->d1->w_msg.msg_type; + sent_msg->record_type = record_type; /* save current state */ - frag->msg_header.saved_retransmit_state.wrlmethod = s->rlayer.wrlmethod; - frag->msg_header.saved_retransmit_state.wrl = s->rlayer.wrl; - + sent_msg->saved_retransmit_state.wrlmethod = s->rlayer.wrlmethod; + sent_msg->saved_retransmit_state.wrl = s->rlayer.wrl; + prio = dtls1_get_queue_priority(sent_msg->msg_info.msg_seq, sent_msg->record_type); memset(seq64be, 0, sizeof(seq64be)); - seq64be[6] = - (unsigned - char)(dtls1_get_queue_priority(frag->msg_header.seq, - frag->msg_header.is_ccs) >> 8); - seq64be[7] = - (unsigned - char)(dtls1_get_queue_priority(frag->msg_header.seq, - frag->msg_header.is_ccs)); - - item = pitem_new(seq64be, frag); + seq64be[6] = (unsigned char)(prio >> 8); + seq64be[7] = (unsigned char)prio; + + item = pitem_new(seq64be, sent_msg); if (item == NULL) { - dtls1_hm_fragment_free(frag); + dtls1_sent_msg_free(sent_msg); return 0; } @@ -1204,7 +1192,7 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) int ret; /* XDTLS: for now assuming that read/writes are blocking */ pitem *item; - hm_fragment *frag; + dtls_sent_msg *sent_msg; unsigned long header_length; unsigned char seq64be[8]; struct dtls1_retransmit_state saved_state; @@ -1222,21 +1210,20 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) } *found = 1; - frag = (hm_fragment *)item->data; + sent_msg = (dtls_sent_msg *)item->data; - if (frag->msg_header.is_ccs) + if (sent_msg->record_type == SSL3_RT_CHANGE_CIPHER_SPEC) header_length = DTLS1_CCS_HEADER_LENGTH; else header_length = DTLS1_HM_HEADER_LENGTH; - memcpy(s->init_buf->data, frag->fragment, - frag->msg_header.msg_len + header_length); - s->init_num = frag->msg_header.msg_len + header_length; + memcpy(s->init_buf->data, sent_msg->msg_buf, + sent_msg->msg_info.msg_body_len + header_length); + s->init_num = sent_msg->msg_info.msg_body_len + header_length; - dtls1_set_message_header_int(s, frag->msg_header.type, - frag->msg_header.msg_len, - frag->msg_header.seq, 0, - frag->msg_header.frag_len); + s->d1->w_msg.msg_type = sent_msg->msg_info.msg_type; + s->d1->w_msg.msg_body_len = sent_msg->msg_info.msg_body_len; + s->d1->w_msg.msg_seq = sent_msg->msg_info.msg_seq; /* save current state */ saved_state.wrlmethod = s->rlayer.wrlmethod; @@ -1245,8 +1232,8 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) s->d1->retransmitting = 1; /* restore state in which the message was originally sent */ - s->rlayer.wrlmethod = frag->msg_header.saved_retransmit_state.wrlmethod; - s->rlayer.wrl = frag->msg_header.saved_retransmit_state.wrl; + s->rlayer.wrlmethod = sent_msg->saved_retransmit_state.wrlmethod; + s->rlayer.wrl = sent_msg->saved_retransmit_state.wrl; /* * The old wrl may be still pointing at an old BIO. Update it to what we're @@ -1254,8 +1241,7 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) */ s->rlayer.wrlmethod->set1_bio(s->rlayer.wrl, s->wbio); - ret = dtls1_do_write(s, frag->msg_header.is_ccs ? - SSL3_RT_CHANGE_CIPHER_SPEC : SSL3_RT_HANDSHAKE); + ret = dtls1_do_write(s, sent_msg->record_type); /* restore current state */ s->rlayer.wrlmethod = saved_state.wrlmethod; @@ -1267,58 +1253,26 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) return ret; } -void dtls1_set_message_header(SSL_CONNECTION *s, - unsigned char mt, size_t len, - size_t frag_off, size_t frag_len) -{ - if (frag_off == 0) { - s->d1->handshake_write_seq = s->d1->next_handshake_write_seq; - s->d1->next_handshake_write_seq++; - } - - dtls1_set_message_header_int(s, mt, len, s->d1->handshake_write_seq, - frag_off, frag_len); -} - -/* don't actually do the writing, wait till the MTU has been retrieved */ -static void -dtls1_set_message_header_int(SSL_CONNECTION *s, unsigned char mt, - size_t len, unsigned short seq_num, - size_t frag_off, size_t frag_len) -{ - struct hm_header_st *msg_hdr = &s->d1->w_msg_hdr; - - msg_hdr->type = mt; - msg_hdr->msg_len = len; - msg_hdr->seq = seq_num; - msg_hdr->frag_off = frag_off; - msg_hdr->frag_len = frag_len; -} - -void dtls1_get_message_header(const unsigned char *data, struct - hm_header_st *msg_hdr) -{ - memset(msg_hdr, 0, sizeof(*msg_hdr)); - msg_hdr->type = *(data++); - n2l3(data, msg_hdr->msg_len); - - n2s(data, msg_hdr->seq); - n2l3(data, msg_hdr->frag_off); - n2l3(data, msg_hdr->frag_len); -} - int dtls1_set_handshake_header(SSL_CONNECTION *s, WPACKET *pkt, int htype) { if (htype == SSL3_MT_CHANGE_CIPHER_SPEC) { s->d1->handshake_write_seq = s->d1->next_handshake_write_seq; - dtls1_set_message_header_int(s, SSL3_MT_CCS, 0, - s->d1->handshake_write_seq, 0, 0); + + s->d1->w_msg.msg_type = SSL3_MT_CCS; + s->d1->w_msg.msg_body_len = 0; + s->d1->w_msg.msg_seq = s->d1->handshake_write_seq; + if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS)) return 0; } else { size_t subpacket_offset = DTLS1_HM_HEADER_LENGTH - SSL3_HM_HEADER_LENGTH; - dtls1_set_message_header(s, htype, 0, 0, 0); + s->d1->handshake_write_seq = s->d1->next_handshake_write_seq; + s->d1->next_handshake_write_seq++; + + s->d1->w_msg.msg_type = htype; + s->d1->w_msg.msg_body_len = 0; + s->d1->w_msg.msg_seq = s->d1->handshake_write_seq; /* Set the content type and 3 bytes for the message len */ if (!WPACKET_put_bytes_u8(pkt, htype) @@ -1336,23 +1290,26 @@ int dtls1_set_handshake_header(SSL_CONNECTION *s, WPACKET *pkt, int htype) int dtls1_close_construct_packet(SSL_CONNECTION *s, WPACKET *pkt, int htype) { size_t msglen; + int record_type; + + /* Convert from possible dummy message type */ + record_type = (htype == SSL3_MT_CHANGE_CIPHER_SPEC) ? SSL3_RT_CHANGE_CIPHER_SPEC + : SSL3_RT_HANDSHAKE; if ((htype != SSL3_MT_CHANGE_CIPHER_SPEC && !WPACKET_close(pkt)) || !WPACKET_get_length(pkt, &msglen) || msglen > INT_MAX) return 0; - if (htype != SSL3_MT_CHANGE_CIPHER_SPEC) { - s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH; - s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH; - } + if (htype != SSL3_MT_CHANGE_CIPHER_SPEC) + s->d1->w_msg.msg_body_len = msglen - DTLS1_HM_HEADER_LENGTH; + s->init_num = msglen; s->init_off = 0; if (htype != DTLS1_MT_HELLO_VERIFY_REQUEST) { /* Buffer the message to handle re-xmits */ - if (!dtls1_buffer_message(s, htype == SSL3_MT_CHANGE_CIPHER_SPEC - ? 1 : 0)) + if (!dtls1_buffer_sent_message(s, record_type)) return 0; } diff --git a/util/indent.pro b/util/indent.pro index bc626e4a4bc..dfe7b6a418f 100644 --- a/util/indent.pro +++ b/util/indent.pro @@ -600,6 +600,8 @@ -T clock_t -T custom_ext_methods -T hm_fragment +-T dtls_msg_info +-T dtls_sent_msg -T ssl_ctx_st -T ssl_flag_tbl -T ssl_st