]> git.ipfire.org Git - thirdparty/haproxy.git/commitdiff
MEDIUM: ssl: Add ktls support for AWS-LC.
authorOlivier Houchard <ohouchard@haproxy.com>
Thu, 19 Jun 2025 16:44:22 +0000 (18:44 +0200)
committerOlivier Houchard <cognet@ci0.org>
Wed, 20 Aug 2025 16:33:11 +0000 (18:33 +0200)
Add ktls support for AWS-LC. As it does not know anything
about ktls, it means extracting keys from the ssl lib, and provide them
to the kernel. At which point we can use regular recvmsg()/sendmsg()
calls.
This patch only provides support for TLS 1.2, AWS-LC provides a
different way to extract keys for TLS 1.3.
Note that this may work with BoringSSL too, but it has not been tested.

include/haproxy/openssl-compat.h
src/ssl_sock.c

index 0c03cd9913c3cc5140ad9acc858fa1e961a234eb..ffced37cdbbcb9760b1c0e6923815deb84cd906c 100644 (file)
@@ -578,6 +578,11 @@ static inline unsigned long ERR_peek_error_func(const char **func)
 
 #endif /* HAVE_VANILLA_OPENSSL && OPENSSL_VERSION_NUMBER >= 0x3000000fL */
 
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+#include <openssl/hkdf.h>
+#define HA_USE_KTLS
+#endif /* OPENSSL_IS_BORINGSSL || OPENSSL_IS_AWSLC */
+
 #endif /* USE_KTLS */
 
 #endif /* _HAPROXY_OPENSSL_COMPAT_H */
index ee2ab74baf3838c64ef31996514b137e13115bab..1a7473cb26c38764d16e1eb7d3177bbda8a0a210 100644 (file)
@@ -6031,6 +6031,196 @@ static int ssl_remove_xprt(struct connection *conn, void *xprt_ctx, void *toremo
        return (ctx->xprt->remove_xprt(conn, ctx->xprt_ctx, toremove_ctx, newops, newctx));
 }
 
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_AWSLC) || defined(OPENSSL_IS_BORINGSSL)
+static void ssl_sock_setup_ktls(struct ssl_sock_ctx *ctx)
+{
+       struct kinfo {
+               struct tls_crypto_info info;
+               /*
+                * Should be enough for key + iv + salt + seq for
+                * every cipher.
+                */
+               unsigned char buf[68];
+       } info;
+       struct {
+               int nid;
+               int tls_cipher;
+               int key_size;
+               int salt_size;
+               int iv_size;
+               int seq_size;
+       } known_ciphers[] = {
+#ifdef TLS_CIPHER_AES_GCM_128
+               { NID_aes_128_gcm, TLS_CIPHER_AES_GCM_128, TLS_CIPHER_AES_GCM_128_KEY_SIZE, TLS_CIPHER_AES_GCM_128_SALT_SIZE, TLS_CIPHER_AES_GCM_128_IV_SIZE, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE },
+#endif
+#ifdef TLS_CIPHER_AES_GCM_256
+               { NID_aes_256_gcm, TLS_CIPHER_AES_GCM_256, TLS_CIPHER_AES_GCM_256_KEY_SIZE, TLS_CIPHER_AES_GCM_256_SALT_SIZE, TLS_CIPHER_AES_GCM_256_IV_SIZE, TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE },
+#endif
+#ifdef TLS_CIPHER_AES_CCM_128
+               { NID_aes_128_ccm, TLS_CIPHER_AES_CCM_128, TLS_CIPHER_AES_CCM_128_KEY_SIZE, TLS_CIPHER_AES_CCM_128_SALT_SIZE, TLS_CIPHER_AES_CCM_128_IV_SIZE, TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE },
+#endif
+#ifdef TLS_CIPHER_CHACHA20_POLY1305
+               { NID_chacha20_poly1305, TLS_CIPHER_CHACHA20_POLY1305, TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE, TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE, TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE, TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE },
+#endif
+#if defined(TLS_CIPHER_SM4_GCM) && defined(NID_sm4_gcm)
+               { NID_sm4_gcm, TLS_CIPHER_SM4_GCM, TLS_CIPHER_SM4_GCM_KEY_SIZE,cTLS_CIPHER_SM4_GCM_SALT_SIZE, TLS_CIPHER_SM4_GCM_IV_SIZE, TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE },
+#endif
+#if defined(TLS_CIPHER_SM4_CCM) && defined(NID_sm4_ccm)
+               { NID_sm4_ccm, TLS_CIPHER_SM4_CCM, TLS_CIPHER_SM4_CCM_KEY_SIZE,cTLS_CIPHER_SM4_CCM_SALT_SIZE, TLS_CIPHER_SM4_CCM_IV_SIZE, TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE },
+#endif
+#if defined(TLS_CIPHER_ARIA_GCM_128) && defined(NID_aria_128_gcm)
+               { NID_aria_128_gcm, TLS_CIPHER_ARIA_GCM_128, TLS_CIPHER_ARIA_GCM_128_KEY_SIZE, TLS_CIPHER_ARIA_GCM_128_SALT_SIZE, TLS_CIPHER_ARIA_GCM_128_IV_SIZE, TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE },
+#endif
+#if defined(TLS_CIPHER_ARIA_GCM_256) && defined(NID_aria_256_gcm)
+               { NID_aria_256_gcm, TLS_CIPHER_ARIA_GCM_256, TLS_CIPHER_ARIA_GCM_256_KEY_SIZE, TLS_CIPHER_ARIA_GCM_256_SALT_SIZE, TLS_CIPHER_ARIA_GCM_256_IV_SIZE, TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE },
+#endif
+
+       };
+
+       SSL *ssl = ctx->ssl;
+       unsigned char buf[128];
+       uint64_t seq;
+       int info_size;
+       int key_size, salt_size, iv_size, seq_size;
+       int is_tls_12;
+       int nid, i;
+
+       if (!(ctx->flags & SSL_SOCK_F_KTLS_ENABLED))
+               return;
+
+       switch (SSL_version(ctx->ssl)) {
+               case TLS_1_2_VERSION:
+                       is_tls_12 = 1;
+                       break;
+               default:
+                       ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
+                       return;
+       }
+
+       nid = SSL_CIPHER_get_cipher_nid(SSL_get_current_cipher(ssl));
+
+       for (i = 0; i < sizeof(known_ciphers) / sizeof(known_ciphers[0]); i++) {
+               if (known_ciphers[i].nid == nid)
+                       break;
+       }
+       if (i == sizeof(known_ciphers) / sizeof(known_ciphers[0])) {
+               ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
+               return;
+       }
+
+       key_size = known_ciphers[i].key_size;
+       salt_size = known_ciphers[i].salt_size;
+       iv_size = known_ciphers[i].iv_size;
+       seq_size = known_ciphers[i].seq_size;
+
+       info_size = sizeof(struct tls_crypto_info) + key_size + salt_size + iv_size + seq_size;
+
+       /*
+        * If new ciphers are added, wy may have to increase the buffer size
+        */
+       BUG_ON(key_size + salt_size + iv_size + seq_size > sizeof(info.buf));
+
+       info.info.version = is_tls_12 ? TLS_1_2_VERSION : TLS_1_3_VERSION;
+       info.info.cipher_type = known_ciphers[i].tls_cipher;
+
+       if (is_tls_12) {
+               unsigned char iv[iv_size];
+               int block_key_size = 2 * key_size + 2 * salt_size;
+               int i;
+
+               /*
+                * We may have to increase buf size if new ciphers are
+                * added with bigger key/salt.
+                */
+               BUG_ON(block_key_size > sizeof(buf));
+
+               if (SSL_get_key_block_len(ssl) != block_key_size) {
+                       ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
+                       goto out;
+               }
+
+               if (SSL_generate_key_block(ssl, buf, block_key_size) != 1) {
+                       ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
+                       goto out;
+               }
+
+               /*
+                * The key block contains :
+                * - client key
+                *   server key
+                *   client salt
+                *   server salt
+                */
+               /*
+                * First, prepare the RX side
+                * The oldest linux versions do not support RTX, that way
+                * we will fail before setting the TX side.
+                */
+               seq = SSL_get_read_sequence(ssl);
+               seq = my_htonll(seq);
+               for (i = 0; i < iv_size; i++)
+                       iv[i] = (unsigned char)statistical_prng_range(256);
+               /* IV */
+               memcpy(&info.buf[0], &iv, iv_size);
+
+               if (!conn_is_back(ctx->conn)) {
+                       /* Key */
+                       memcpy(&info.buf[iv_size], &buf[0], key_size);
+                       /* Salt */
+                       memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size], salt_size);
+               } else {
+                       /* Key */
+                       memcpy(&info.buf[iv_size], &buf[key_size], key_size);
+                       /* Salt */
+                       memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size + salt_size], salt_size);
+               }
+               /* Record seq number */
+               memcpy(&info.buf[iv_size + key_size + salt_size], &seq, seq_size);
+               if (ktls_set_key(ctx, &info, info_size, 0) != 0) {
+                       ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
+                       goto out;
+               }
+               /*
+                * Now do the TX side
+                */
+               seq = SSL_get_write_sequence(ssl);
+               seq = my_htonll(seq);
+               for (i = 0; i < iv_size; i++)
+                       iv[i] = (unsigned char)statistical_prng_range(256);
+               memcpy(&info.buf[0], &iv, iv_size);
+               if (!conn_is_back(ctx->conn)) {
+                       /* Key */
+                       memcpy(&info.buf[iv_size], &buf[key_size], key_size);
+                       /* Salt */
+                       memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size + salt_size], salt_size);
+               } else {
+                       /* Key */
+                       memcpy(&info.buf[iv_size], &buf[0], key_size);
+                       /* Salt */
+                       memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size], salt_size);
+               }
+               memcpy(&info.buf[iv_size + key_size + salt_size], &seq, seq_size);
+               if (ktls_set_key(ctx, &info, info_size, 1) != 0) {
+                       /*
+                        * Not much we can do at this point. TLS has been
+                        * enabled for RX, we can't disable it, we won't
+                        * try to support only one side, so give up with
+                        * that connection.
+                        */
+                       ctx->conn->flags |= CO_FL_ERROR;
+                       ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
+                       goto out;
+               }
+               ctx->flags |= SSL_SOCK_F_KTLS_SEND | SSL_SOCK_F_KTLS_RECV;
+       }
+out:
+       return;
+}
+
+#endif
+#endif
+
 struct task *ssl_sock_io_cb(struct task *t, void *context, unsigned int state)
 {
        struct tasklet *tl = (struct tasklet *)t;
@@ -6069,6 +6259,12 @@ struct task *ssl_sock_io_cb(struct task *t, void *context, unsigned int state)
                if (!(ctx->conn->flags & CO_FL_SSL_WAIT_HS)) {
                        /* handshake completed, leave the bulk queue */
                        _HA_ATOMIC_AND(&tl->state, ~TASK_HEAVY);
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_AWSLC) || defined(OPENSSL_IS_BORINGSSL)
+                       if (ctx->flags & SSL_SOCK_F_KTLS_ENABLED)
+                               ssl_sock_setup_ktls(ctx);
+#endif
+#endif
                }
        }
        /* If we had an error, or the handshake is done and I/O is available,
@@ -6184,7 +6380,6 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu
         * EINTR too.
         */
        while (count > 0) {
-
                try = b_contig_space(buf);
                if (!try)
                        break;
@@ -6192,6 +6387,14 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu
                if (try > count)
                        try = count;
 
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+               if (ctx->flags & SSL_SOCK_F_KTLS_RECV) {
+                       ret = ctx->xprt->rcv_buf(ctx->conn, ctx->xprt_ctx, buf, try, NULL, NULL, 0);
+
+               } else
+#endif
+#endif
                ret = SSL_read(ctx->ssl, b_tail(buf), try);
 
                if (conn->flags & CO_FL_ERROR) {
@@ -6199,12 +6402,33 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu
                        goto out_error;
                }
                if (ret > 0) {
-                       b_add(buf, ret);
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+                       /*
+                        * The next xprt already adjusted the buffer,
+                        * so we should not do it.
+                        */
+                       if (!(ctx->flags & SSL_SOCK_F_KTLS_RECV))
+#endif
+#endif
+                               b_add(buf, ret);
                        done += ret;
                        count -= ret;
                        TRACE_DEVEL("Post SSL_read success", SSL_EV_CONN_RECV, conn, &ret);
                }
                else {
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+                       if (ctx->flags & SSL_SOCK_F_KTLS_RECV)
+                               /*
+                                * At this point the underlying xprt already
+                                * set any connection error, and we can't
+                                * ask the SSL lib, so we can stop now.
+                                */
+                               break;
+                       else
+#endif
+#endif
                        ret =  SSL_get_error(ctx->ssl, ret);
                        if (ret == SSL_ERROR_WANT_WRITE) {
                                /* handshake is running, and it needs to enable write */
@@ -6317,6 +6541,12 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s
         * in which case we accept to do it once again.
         */
        while (count) {
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+               int ktls_error = 0;
+#endif
+#endif
+
 #ifdef SSL_READ_EARLY_DATA_SUCCESS
                size_t written_data;
 #endif
@@ -6379,10 +6609,36 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s
                                TRACE_PROTO("Write early data", SSL_EV_CONN_SEND|SSL_EV_CONN_SEND_EARLY, conn, &ret);
                        }
 
-               } else
+               } else {
+#endif
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+                       if (ctx->flags & SSL_SOCK_F_KTLS_SEND) {
+                               struct buffer tmpbuf;
+
+                               tmpbuf.size = b_data(buf) - done;
+                               tmpbuf.data = tmpbuf.size;
+                               tmpbuf.area = b_peek(buf, done);
+                               tmpbuf.head = 0;
+                               ret = ctx->xprt->snd_buf(ctx->conn, ctx->xprt_ctx, &tmpbuf, try, NULL, 0, (ctx->xprt_st & SSL_SOCK_SEND_MORE) ? CO_SFL_MSG_MORE : 0);
+                               if (ret < try) {
+                                       if (errno == EINTR)
+                                               continue;
+                                       else if (!(conn->flags & CO_FL_ERROR))
+                                               ktls_error = SSL_ERROR_WANT_WRITE;
+                                       else {
+                                               ktls_error = SSL_ERROR_SSL;
+                                       }
+                               }
+
+                       } else
 #endif
-                       ret = SSL_write(ctx->ssl, b_peek(buf, done), try);
+#endif
+                               ret = SSL_write(ctx->ssl, b_peek(buf, done), try);
 
+#ifdef SSL_READ_EARLY_DATA_SUCCESS
+               }
+#endif
                if (conn->flags & CO_FL_ERROR) {
                        /* CO_FL_ERROR may be set by ssl_sock_infocbk */
                        goto out_error;
@@ -6396,6 +6652,13 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s
                        TRACE_DEVEL("Post SSL_write success", SSL_EV_CONN_SEND, conn, &ret);
                }
                else {
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
+                       if (ctx->flags & SSL_SOCK_F_KTLS_SEND)
+                               ret = ktls_error;
+                       else
+#endif
+#endif
                        ret = SSL_get_error(ctx->ssl, ret);
 
                        if (ret == SSL_ERROR_WANT_WRITE) {
@@ -6542,6 +6805,12 @@ static void ssl_sock_shutw(struct connection *conn, void *xprt_ctx, int clean)
 
        TRACE_ENTER(SSL_EV_CONN_END, conn);
 
+#ifdef HA_USE_KTLS
+#if defined(OPENSSL_IS_AWSLC) || defined(OPENSSL_IS_BORINGSSL)
+       if (ctx->flags & (SSL_SOCK_F_KTLS_RECV | SSL_SOCK_F_KTLS_SEND))
+               return;
+#endif
+#endif
        if (conn->flags & (CO_FL_WAIT_XPRT | CO_FL_SSL_WAIT_HS))
                return;
        conn_report_term_evt(conn, tevt_loc_xprt, xprt_tevt_type_shutw);