From: Pascal Knecht Date: Mon, 21 Sep 2020 20:19:34 +0000 (+0200) Subject: tls-server: Refactor writing of key share extensions X-Git-Tag: 5.9.2rc1~23^2~49 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5c4cb40e476ea776a94d7bfde825e88ed1c036af;p=thirdparty%2Fstrongswan.git tls-server: Refactor writing of key share extensions Client and server now share the same code to write this extension. --- diff --git a/src/libtls/tls_peer.c b/src/libtls/tls_peer.c index 4e5c2cc039..7d6c1ff7b7 100644 --- a/src/libtls/tls_peer.c +++ b/src/libtls/tls_peer.c @@ -157,6 +157,10 @@ struct private_tls_peer_t { chunk_t cert_types; }; +/* Implemented in tls_server.c */ +bool tls_write_key_share(bio_writer_t **key_share, tls_named_group_t group, + diffie_hellman_t *dh); + /** * Verify the DH group/key type requested by the server is valid. */ @@ -1202,7 +1206,6 @@ static status_t send_client_hello(private_tls_peer_t *this, enumerator_t *enumerator; int count, i, v; rng_t *rng; - chunk_t pub; htoun32(&this->client_random, time(NULL)); rng = lib->crypto->create_rng(lib->crypto, RNG_WEAK); @@ -1352,34 +1355,21 @@ static status_t send_client_hello(private_tls_peer_t *this, extensions->write_data16(extensions, signatures->get_buf(signatures)); signatures->destroy(signatures); - if (this->dh) + if (this->tls->get_version_max(this->tls) >= TLS_1_3 && + this->dh) { DBG2(DBG_TLS, "sending extension: %N", tls_extension_names, TLS_EXT_KEY_SHARE); - if (!this->dh->get_my_public_value(this->dh, &pub)) + extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE); + if (!tls_write_key_share(&key_share, selected_curve, this->dh)) { this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR); extensions->destroy(extensions); return NEED_MORE; } - extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE); - key_share = bio_writer_create(pub.len + 6); - key_share->write_uint16(key_share, selected_curve); - if (selected_curve == TLS_CURVE25519 || - selected_curve == TLS_CURVE448) - { - key_share->write_data16(key_share, pub); - } - else - { /* classic format (see RFC 8446, section 4.2.8.2) */ - key_share->write_uint16(key_share, pub.len + 1); - key_share->write_uint8(key_share, TLS_ANSI_UNCOMPRESSED); - key_share->write_data(key_share, pub); - } key_share->wrap16(key_share); extensions->write_data16(extensions, key_share->get_buf(key_share)); key_share->destroy(key_share); - free(pub.ptr); } writer->write_data16(writer, extensions->get_buf(extensions)); diff --git a/src/libtls/tls_server.c b/src/libtls/tls_server.c index c858252087..fb897cf0ef 100644 --- a/src/libtls/tls_server.c +++ b/src/libtls/tls_server.c @@ -340,8 +340,7 @@ static status_t process_client_hello(private_tls_server_t *this, { DBG1(DBG_TLS, "invalid %N extension", tls_extension_names, extension_type); - this->alert->add(this->alert, TLS_FATAL, - TLS_DECODE_ERROR); + this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR); extensions->destroy(extensions); extension->destroy(extension); return NEED_MORE; @@ -915,15 +914,46 @@ METHOD(tls_handshake_t, process, status_t, return NEED_MORE; } +/** + * Write public key into key share extension + */ +bool tls_write_key_share(bio_writer_t **key_share, tls_named_group_t group, + diffie_hellman_t *dh) +{ + bio_writer_t *writer; + chunk_t pub; + + if (!dh || !dh->get_my_public_value(dh, &pub)) + { + return FALSE; + } + *key_share = writer = bio_writer_create(pub.len + 7); + writer->write_uint16(writer, group); + if (group == TLS_CURVE25519 || + group == TLS_CURVE448) + { + writer->write_data16(writer, pub); + } + else + { /* classic format (see RFC 8446, section 4.2.8.2) */ + writer->write_uint16(writer, pub.len + 1); + writer->write_uint8(writer, TLS_ANSI_UNCOMPRESSED); + writer->write_data(writer, pub); + } + free(pub.ptr); + return TRUE; +} + /** * Send ServerHello message */ static status_t send_server_hello(private_tls_server_t *this, tls_handshake_type_t *type, bio_writer_t *writer) { - bio_writer_t *extensions, *key_share; - tls_version_t version = this->tls->get_version_max(this->tls); - chunk_t pub; + bio_writer_t *key_share, *extensions; + tls_version_t version; + + version = this->tls->get_version_max(this->tls); /* cap legacy version at TLS 1.2 for middlebox compatibility */ writer->write_uint16(writer, min(TLS_1_2, version)); @@ -948,36 +978,18 @@ static status_t send_server_hello(private_tls_server_t *this, extensions->write_uint16(extensions, 2); extensions->write_uint16(extensions, version); - if (this->dh) - { - tls_named_group_t selected_curve = this->requested_curve; + DBG2(DBG_TLS, "sending extension: %N", + tls_extension_names, TLS_EXT_KEY_SHARE); + extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE); - DBG2(DBG_TLS, "sending extension: %N", - tls_extension_names, TLS_EXT_KEY_SHARE); - if (!this->dh->get_my_public_value(this->dh, &pub)) - { - this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR); - extensions->destroy(extensions); - return NEED_MORE; - } - extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE); - key_share = bio_writer_create(pub.len + 6); - key_share->write_uint16(key_share, selected_curve); - if (selected_curve == TLS_CURVE25519 || - selected_curve == TLS_CURVE448) - { - key_share->write_data16(key_share, pub); - } - else - { /* classic format (see RFC 8446, section 4.2.8.2) */ - key_share->write_uint16(key_share, pub.len + 1); - key_share->write_uint8(key_share, TLS_ANSI_UNCOMPRESSED); - key_share->write_data(key_share, pub); - } - extensions->write_data16(extensions, key_share->get_buf(key_share)); - key_share->destroy(key_share); - free(pub.ptr); + if (!tls_write_key_share(&key_share, this->requested_curve, this->dh)) + { + this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR); + extensions->destroy(extensions); + return NEED_MORE; } + extensions->write_data16(extensions, key_share->get_buf(key_share)); + key_share->destroy(key_share); writer->write_data16(writer, extensions->get_buf(extensions)); extensions->destroy(extensions);