]> git.ipfire.org Git - people/ms/strongswan.git/commitdiff
tls-server: Support multiple client key shares
authorPascal Knecht <pascal.knecht@hsr.ch>
Fri, 2 Oct 2020 16:11:45 +0000 (18:11 +0200)
committerTobias Brunner <tobias@strongswan.org>
Fri, 12 Feb 2021 13:35:23 +0000 (14:35 +0100)
A client can send one or multiple key shares from which the server picks
one it supports (checked in its preferred order).  A retry is requested if
none of the key shares are supported.

src/libtls/tls_server.c

index fa90091a128c39b6d2b9496876a5005286972878..3d114d589908f4318056f0094265ea90fab92bb8 100644 (file)
@@ -22,6 +22,7 @@
 
 #include <utils/debug.h>
 #include <credentials/certificates/x509.h>
+#include <collections/array.h>
 
 typedef struct private_tls_server_t private_tls_server_t;
 
@@ -263,6 +264,38 @@ static bool peer_supports_curve(private_tls_server_t *this,
        return FALSE;
 }
 
+/**
+ * TLS 1.3 key exchange key share
+ */
+typedef struct {
+       uint16_t curve;
+       chunk_t key_share;
+} key_share_t;
+
+/**
+ * Check if peer sent a key share of a given TLS named DH group
+ */
+static bool peer_offered_curve(array_t *key_shares, tls_named_group_t curve,
+                                                          key_share_t *out)
+{
+       key_share_t peer;
+       int i;
+
+       for (i = 0; i < array_count(key_shares); i++)
+       {
+               array_get(key_shares, i, &peer);
+               if (curve == peer.curve)
+               {
+                       if (out)
+                       {
+                               *out = peer;
+                       }
+                       return TRUE;
+               }
+       }
+       return FALSE;
+}
+
 /**
  * Check if client is currently retrying to connect to the server.
  */
@@ -277,10 +310,10 @@ static bool retrying(private_tls_server_t *this)
 static status_t process_client_hello(private_tls_server_t *this,
                                                                         bio_reader_t *reader)
 {
-       uint16_t legacy_version = 0, version = 0, key_share_length, key_type = 0;
-       uint16_t extension_type = 0;
+       uint16_t legacy_version = 0, version = 0, extension_type = 0;
        chunk_t random, session, ciphers, versions = chunk_empty, compression;
-       chunk_t ext = chunk_empty, key_share = chunk_empty;
+       chunk_t ext = chunk_empty, key_shares = chunk_empty;
+       key_share_t peer = {0};
        chunk_t extension_data = chunk_empty;
        bio_reader_t *extensions, *extension;
        tls_cipher_suite_t *suites;
@@ -345,10 +378,7 @@ static status_t process_client_hello(private_tls_server_t *this,
                                }
                                break;
                        case TLS_EXT_KEY_SHARE:
-                               if (!extension->read_uint16(extension, &key_share_length) ||
-                                       !extension->read_uint16(extension, &key_type) ||
-                                       !extension->read_data16(extension, &key_share) ||
-                                       !key_share.len)
+                               if (!extension->read_data16(extension, &key_shares))
                                {
                                        DBG1(DBG_TLS, "invalid %N extension",
                                                 tls_extension_names, extension_type);
@@ -470,17 +500,38 @@ static status_t process_client_hello(private_tls_server_t *this,
                tls_named_group_t curve, requesting_curve = 0;
                enumerator_t *enumerator;
                chunk_t shared_secret = chunk_empty;
+               array_t *peer_key_shares;
+
+               peer_key_shares = array_create(sizeof(key_share_t), 1);
+               extension = bio_reader_create(key_shares);
+               while (extension->remaining(extension))
+               {
+                       if (!extension->read_uint16(extension, &peer.curve) ||
+                               !extension->read_data16(extension, &peer.key_share) ||
+                               !peer.key_share.len)
+                       {
+                               DBG1(DBG_TLS, "invalid %N extension",
+                                        tls_extension_names, extension_type);
+                               this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
+                               extension->destroy(extension);
+                               array_destroy(peer_key_shares);
+                               return NEED_MORE;
+                       }
+                       array_insert(peer_key_shares, ARRAY_TAIL, &peer);
+               }
+               extension->destroy(extension);
 
                enumerator = this->crypto->create_ec_enumerator(this->crypto);
                while (enumerator->enumerate(enumerator, &group, &curve))
                {
                        if (!requesting_curve &&
-                               curve != key_type &&
-                               peer_supports_curve(this, curve))
+                               peer_supports_curve(this, curve) &&
+                               !peer_offered_curve(peer_key_shares, curve, NULL))
                        {
                                requesting_curve = curve;
                        }
-                       if (curve == key_type && peer_supports_curve(this, curve))
+                       if (peer_supports_curve(this, curve) &&
+                               peer_offered_curve(peer_key_shares, curve, &peer))
                        {
                                DBG1(DBG_TLS, "using key exchange %N",
                                         tls_named_group_names, curve);
@@ -489,6 +540,7 @@ static status_t process_client_hello(private_tls_server_t *this,
                        }
                }
                enumerator->destroy(enumerator);
+               array_destroy(peer_key_shares);
 
                if (!this->dh)
                {
@@ -499,9 +551,6 @@ static status_t process_client_hello(private_tls_server_t *this,
                                return NEED_MORE;
                        }
 
-                       /* TODO: process all client offered key shares, currently only the
-                        * first key share extensions is processed other offered key shares
-                        * are ignored. */
                        if (!requesting_curve)
                        {
                                DBG1(DBG_TLS, "no mutual supported group in client hello");
@@ -519,21 +568,21 @@ static status_t process_client_hello(private_tls_server_t *this,
                }
                else
                {
-                       if (key_share.len &&
-                               key_type != TLS_CURVE25519 &&
-                               key_type != TLS_CURVE448)
+                       if (peer.key_share.len &&
+                               peer.curve != TLS_CURVE25519 &&
+                               peer.curve != TLS_CURVE448)
                        {       /* classic format (see RFC 8446, section 4.2.8.2) */
-                               if (key_share.ptr[0] != TLS_ANSI_UNCOMPRESSED)
+                               if (peer.key_share.ptr[0] != TLS_ANSI_UNCOMPRESSED)
                                {
                                        DBG1(DBG_TLS, "DH point format '%N' not supported",
-                                                tls_ansi_point_format_names, key_share.ptr[0]);
+                                                tls_ansi_point_format_names, peer.key_share.ptr[0]);
                                        this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR);
                                        return NEED_MORE;
                                }
-                               key_share = chunk_skip(key_share, 1);
+                               peer.key_share = chunk_skip(peer.key_share, 1);
                        }
-                       if (!key_share.len ||
-                               !this->dh->set_other_public_value(this->dh, key_share))
+                       if (!peer.key_share.len ||
+                               !this->dh->set_other_public_value(this->dh, peer.key_share))
                        {
                                DBG1(DBG_TLS, "DH key derivation failed");
                                this->alert->add(this->alert, TLS_FATAL, TLS_HANDSHAKE_FAILURE);