]> git.ipfire.org Git - thirdparty/gnutls.git/commitdiff
KTLS: API ktls_api
authorFrantisek Krenzelok <krenzelok.frantisek@gmail.com>
Fri, 15 Oct 2021 13:00:17 +0000 (15:00 +0200)
committerFrantisek Krenzelok <krenzelok.frantisek@gmail.com>
Tue, 30 Nov 2021 11:27:43 +0000 (12:27 +0100)
ktls is enabled by default, we can check if inicialization was
succesfull with gnutls_transport_is_ktls_enabled

Signed-off-by: Frantisek Krenzelok <krenzelok.frantisek@gmail.com>
13 files changed:
devel/libgnutls.abignore
devel/symbols.last
doc/Makefile.am
doc/manpages/Makefile.am
lib/alert.c
lib/gnutls_int.h
lib/handshake.c
lib/includes/gnutls/socket.h
lib/libgnutls.map
lib/record.c
lib/system/ktls.c
lib/system/ktls.h
tests/gnutls_ktls.c

index 8afa94148ab88a0df91b4b88455ef5b8b49807a3..21d58a529983ba8aa99e018e6f87586aefb4e3e0 100644 (file)
@@ -71,3 +71,6 @@ name = gnutls_sign_set_secure
 
 [suppress_function]
 name = gnutls_sign_set_secure_for_certs
+
+[suppress_function]
+name = gnutls_transport_is_ktls_enabled
index 7bef6663505337d7dcb7cc5a881839aab1dde1a7..b17310dac38046a82e84cd7d756f0fbdb4f3a59a 100644 (file)
@@ -881,6 +881,7 @@ gnutls_transport_get_int2@GNUTLS_3_4
 gnutls_transport_get_int@GNUTLS_3_4
 gnutls_transport_get_ptr2@GNUTLS_3_4
 gnutls_transport_get_ptr@GNUTLS_3_4
+gnutls_transport_is_ktls_enabled@GNUTLS_3_7_2
 gnutls_transport_set_errno@GNUTLS_3_4
 gnutls_transport_set_errno_function@GNUTLS_3_4
 gnutls_transport_set_fastopen@GNUTLS_3_4
index 4f25bf0d5ea8c303e1ffc733ac1f0aa1ae06cfd1..03d73b81ccba3c3643fdd49784d2769cfb4a6d48 100644 (file)
@@ -2146,6 +2146,8 @@ FUNCS += functions/gnutls_transport_get_ptr
 FUNCS += functions/gnutls_transport_get_ptr.short
 FUNCS += functions/gnutls_transport_get_ptr2
 FUNCS += functions/gnutls_transport_get_ptr2.short
+FUNCS += functions/gnutls_transport_is_ktls_enabled
+FUNCS += functions/gnutls_transport_is_ktls_enabled.short
 FUNCS += functions/gnutls_transport_set_errno
 FUNCS += functions/gnutls_transport_set_errno.short
 FUNCS += functions/gnutls_transport_set_errno_function
index 4f39adf0ccbd6d31204ebac030510f33a48638c8..b520bd1c68c3f03e7c36bd80533c893d9d5bd2e8 100644 (file)
@@ -875,6 +875,7 @@ APIMANS += gnutls_transport_get_int.3
 APIMANS += gnutls_transport_get_int2.3
 APIMANS += gnutls_transport_get_ptr.3
 APIMANS += gnutls_transport_get_ptr2.3
+APIMANS += gnutls_transport_is_ktls_enabled.3
 APIMANS += gnutls_transport_set_errno.3
 APIMANS += gnutls_transport_set_errno_function.3
 APIMANS += gnutls_transport_set_fastopen.3
index eda931a1c5783358f402862d76d8ec90670e321c..28ee91b13f2c4ac9dae76db0597e53635db42788 100644 (file)
@@ -182,7 +182,7 @@ gnutls_alert_send(gnutls_session_t session, gnutls_alert_level_t level,
                return ret;
        }
 
-       if (IS_KTLS_ENABLED(session)) {
+       if (IS_KTLS_ENABLED(session, KTLS_SEND)) {
                ret =
                        _gnutls_ktls_send_control_msg(session, GNUTLS_ALERT, data, 2);
        } else {
index 1dbe40485723a4b9428f311e5ea9b7c8e6637690..a660828a57e95f4dee6cb401d333c3817153f2c5 100644 (file)
@@ -176,7 +176,7 @@ typedef enum record_send_state_t {
 #define IS_DTLS(session) (session->internals.transport == GNUTLS_DGRAM)
 
 /* To check whether we have a KTLS enabled */
-#define IS_KTLS_ENABLED(session) (session->internals.ktls_enabled)
+#define IS_KTLS_ENABLED(session, interface) (session->internals.ktls_enabled & interface)
 
 /* the maximum size of encrypted packets */
 #define DEFAULT_MAX_RECORD_SIZE 16384
@@ -1495,10 +1495,7 @@ typedef struct {
        void *epoch_lock;
 
        /* indicates whether or not was KTLS initialized properly. */
-       bool ktls_enabled;
-       int recv_fd;
-       int send_fd;
-
+       int ktls_enabled;
        /* If you add anything here, check _gnutls_handshake_internal_state_clear().
         */
 } internals_st;
index 9d36446e54d0f9db65279ebd1eee5f5cadbd978b..4ddfa66afecb7b0157817b8c5be556d5e1e31901 100644 (file)
@@ -2811,10 +2811,8 @@ int gnutls_handshake(gnutls_session_t session)
        int ret;
 
 #ifdef ENABLE_KTLS
-       int sockin, sockout;
-       gnutls_transport_get_int2(session, &sockin, &sockout);
-       _gnutls_ktls_enable(session, sockin, sockout);  
-#endif 
+       _gnutls_ktls_enable(session);
+#endif
 
        if (unlikely(session->internals.initial_negotiation_completed)) {
                if (vers->tls13_sem) {
@@ -2912,10 +2910,8 @@ int gnutls_handshake(gnutls_session_t session)
        }
 
 #ifdef ENABLE_KTLS
-       if (IS_KTLS_ENABLED(session)) {
-               ret = _gnutls_ktls_set_keys(session);
-               if (ret < 0)
-                       return ret;
+       if (IS_KTLS_ENABLED(session, KTLS_DUPLEX)) {
+               _gnutls_ktls_set_keys(session);
        }
 #endif
 
index 64eb19f896e7b9fc556eb695acacf640b2ccdd73..82f8d2f094f6dda8cdd8dd0fe42a82110a29c0e5 100644 (file)
@@ -43,6 +43,8 @@ void gnutls_transport_set_fastopen(gnutls_session_t session,
                                    socklen_t connect_addrlen,
                                    unsigned int flags);
 
+int gnutls_transport_is_ktls_enabled(gnutls_session_t session);
+
 /* *INDENT-OFF* */
 #ifdef __cplusplus
 }
index dc50c6dba96692749c848ce2780acc8a53eaf99e..109837a5b5517b73f88be5356cb4d6bab48787ff 100644 (file)
@@ -1363,6 +1363,7 @@ GNUTLS_3_7_3
        gnutls_sign_set_secure_for_certs;
        gnutls_digest_set_secure;
        gnutls_protocol_set_enabled;
+       gnutls_transport_is_ktls_enabled;
  local:
        *;
 } GNUTLS_3_7_2;
index ebc07d9e1cfcb92f91bf70d971bb55fcb56be045..d7f8724352a6b89a2ca73fde643045e6db8b912e 100644 (file)
@@ -289,7 +289,7 @@ int gnutls_bye(gnutls_session_t session, gnutls_close_request_t how)
 
        switch (BYE_STATE) {
        case BYE_STATE0:
-               if (!IS_KTLS_ENABLED(session))
+               if (!IS_KTLS_ENABLED(session, KTLS_SEND))
                        ret = _gnutls_io_write_flush(session);
                BYE_STATE = BYE_STATE0;
                if (ret < 0) {
@@ -309,7 +309,7 @@ int gnutls_bye(gnutls_session_t session, gnutls_close_request_t how)
        case BYE_STATE2:
                BYE_STATE = BYE_STATE2;
                if (how == GNUTLS_SHUT_RDWR) {
-                       if (IS_KTLS_ENABLED(session)){
+                       if (IS_KTLS_ENABLED(session, KTLS_SEND)){
                                do {
                                        ret = _gnutls_ktls_recv_int(session,
                                                        GNUTLS_ALERT, NULL, 0);
@@ -2035,7 +2035,7 @@ gnutls_record_send2(gnutls_session_t session, const void *data,
 
        switch(session->internals.rsend_state) {
                case RECORD_SEND_NORMAL:
-                       if (IS_KTLS_ENABLED(session)) {
+                       if (IS_KTLS_ENABLED(session, KTLS_SEND)) {
                                return _gnutls_ktls_send(session, data, data_size);
                        } else {
                                return _gnutls_send_tlen_int(session, GNUTLS_APPLICATION_DATA,
@@ -2306,7 +2306,7 @@ gnutls_record_recv(gnutls_session_t session, void *data, size_t data_size)
                        return gnutls_assert_val(GNUTLS_E_UNAVAILABLE_DURING_HANDSHAKE);
        }
 
-       if (IS_KTLS_ENABLED(session)) {
+       if (IS_KTLS_ENABLED(session, KTLS_RECV)) {
                return _gnutls_ktls_recv(session, data, data_size);
        } else {
                return _gnutls_recv_int(session, GNUTLS_APPLICATION_DATA,
index 7ab1d3215d0b9eb748dfcf5033497039a20a7b45..c54653b49870f219c13ce5288f8fa74756037c46 100644 (file)
 #include "ext/session_ticket.h"
 
 /**
- * gnutls_transport_set_ktls:
+ * gnutls_transport_is_ktls_enabled:
  * @session: is a #gnutls_session_t type.
- * @sockin: is a socket descriptor.
- * @sockout: is a socket descriptor.
  *
- * Enables Kernel TLS for the @session
- * Requieres `tls` kernel module and
- * gnutls configuration with `--enable-ktls`
+ * Checks if KTLS is now enabled and was properly inicialized.
  *
- * Returns: 0 on success error otherwise
+ * Returns: 1 for enabled, 0 otherwise
  *
  * Since: 3.7.2
  **/
-int _gnutls_ktls_enable(gnutls_session_t session, int sockin, int sockout)
+int gnutls_transport_is_ktls_enabled(gnutls_session_t session){
+       if (unlikely(!session->internals.initial_negotiation_completed))
+               return gnutls_assert_val(GNUTLS_E_UNAVAILABLE_DURING_HANDSHAKE);
+
+       return session->internals.ktls_enabled;
+}
+
+int _gnutls_ktls_enable(gnutls_session_t session)
 {
-       if (setsockopt(sockin, SOL_TCP, TCP_ULP, "tls", sizeof ("tls")) < 0)
-               return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+       int sockin, sockout;
+       session->internals.ktls_enabled = 0;
+       gnutls_transport_get_int2(session, &sockin, &sockout);
 
-       session->internals.recv_fd = sockin;
-       session->internals.send_fd = sockin;
+       if (setsockopt(sockin, SOL_TCP, TCP_ULP, "tls", sizeof ("tls")) == 0)
+               session->internals.ktls_enabled |= KTLS_RECV;
 
-       if (sockin != sockout){
-               if (setsockopt(sockout, SOL_TCP, TCP_ULP, "tls", sizeof ("tls")) < 0)
-                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
-               session->internals.send_fd = sockout;
-       }
+       if (sockin != sockout) {
+               if (setsockopt(sockout, SOL_TCP, TCP_ULP, "tls", sizeof ("tls")) == 0)
+                       session->internals.ktls_enabled |= KTLS_SEND;
+       } else
+               session->internals.ktls_enabled |= KTLS_SEND;
 
-       session->internals.ktls_enabled = 1;
        return 0;
 }
 
@@ -72,9 +75,10 @@ int _gnutls_ktls_set_keys(gnutls_session_t session)
        gnutls_datum_t iv;
        gnutls_datum_t cipher_key;
        unsigned char seq_number[8];
+       int sockin, sockout;
        int ret;
 
-       session->internals.ktls_enabled = 0;
+       gnutls_transport_get_int2(session, &sockin, &sockout);
 
        /* check whether or not cipher suite supports ktls
         */
@@ -85,164 +89,174 @@ int _gnutls_ktls_set_keys(gnutls_session_t session)
                return  GNUTLS_E_UNIMPLEMENTED_FEATURE;
        }
 
-       version = (version == GNUTLS_TLS1_2) ? TLS_1_2_VERSION : TLS_1_3_VERSION;
-
        ret = gnutls_record_get_state(session, 1, &mac_key, &iv, &cipher_key,
                                                                   seq_number);
        if (ret < 0) {
                return ret;
        }
 
-       switch (cipher) {
-               case GNUTLS_CIPHER_AES_128_GCM:
-               {
-                       struct tls12_crypto_info_aes_gcm_128 crypto_info;
-
-                       crypto_info.info.version = version;
-                       crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_128;
-
-                       assert(cipher_key.size == TLS_CIPHER_AES_GCM_128_KEY_SIZE);
-
-                       /* for TLS 1.2 IV is generated in kernel */
-                       if (version == TLS_1_2_VERSION) {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-                       } else {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_128_SALT_SIZE
-                                               + TLS_CIPHER_AES_GCM_128_IV_SIZE);
-
-                               memcpy(crypto_info.iv, iv.data +
-                                       TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-                                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
+       if(session->internals.ktls_enabled & KTLS_RECV){
+               switch (cipher) {
+                       case GNUTLS_CIPHER_AES_128_GCM:
+                       {
+                               struct tls12_crypto_info_aes_gcm_128 crypto_info;
+                               memset(&crypto_info, 0, sizeof(crypto_info));
+
+                               crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_128;
+                               assert(cipher_key.size == TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+
+                               /* for TLS 1.2 IV is generated in kernel */
+                               if (version == GNUTLS_TLS1_2) {
+                                       crypto_info.info.version = TLS_1_2_VERSION;
+                                       memcpy(crypto_info.iv, seq_number, TLS_CIPHER_AES_GCM_128_IV_SIZE);
+                               } else {
+                                       crypto_info.info.version = TLS_1_3_VERSION;
+                                       assert(iv.size == TLS_CIPHER_AES_GCM_128_SALT_SIZE
+                                                       + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+                                       memcpy(crypto_info.iv, iv.data +
+                                               TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+                                               TLS_CIPHER_AES_GCM_128_IV_SIZE);
+                               }
+
+                               memcpy(crypto_info.salt, iv.data,
+                               TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+                               memcpy(crypto_info.rec_seq, seq_number,
+                               TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
+                               memcpy(crypto_info.key, cipher_key.data,
+                               TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+
+                               if (setsockopt (sockin, SOL_TLS, TLS_RX,
+                                               &crypto_info, sizeof (crypto_info))) {
+                                       session->internals.ktls_enabled ^= KTLS_RECV;
+                                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+                               }
                        }
-
-                       memcpy(crypto_info.salt, iv.data,
-                       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-                       memcpy(crypto_info.rec_seq, seq_number,
-                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
-                       memcpy(crypto_info.key, cipher_key.data,
-                       TLS_CIPHER_AES_GCM_128_KEY_SIZE);
-
-                       if (setsockopt(session->internals.recv_fd, SOL_TLS, TLS_RX,
-                               &crypto_info, sizeof (crypto_info))) {
-                               return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
-                       }
-               }
-               break;
-               case GNUTLS_CIPHER_AES_256_GCM:
-               {
-                       struct tls12_crypto_info_aes_gcm_256 crypto_info;
-
-                       crypto_info.info.version = version;
-                       crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_256;
-
-                       assert(cipher_key.size == TLS_CIPHER_AES_GCM_256_KEY_SIZE);
-
-                       /* for TLS 1.2 IV is generated in kernel */
-                       if (version == TLS_1_2_VERSION) {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_256_SALT_SIZE);
-                       } else {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_256_SALT_SIZE
-                                               + TLS_CIPHER_AES_GCM_256_IV_SIZE);
-
-                               memcpy(crypto_info.iv, iv.data + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
-                               TLS_CIPHER_AES_GCM_256_IV_SIZE);
-                       }
-
-                       memcpy(crypto_info.salt, iv.data,
-                       TLS_CIPHER_AES_GCM_256_SALT_SIZE);
-                       memcpy(crypto_info.rec_seq, seq_number,
-                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
-                       memcpy(crypto_info.key, cipher_key.data,
-                       TLS_CIPHER_AES_GCM_256_KEY_SIZE);
-
-                       if (setsockopt(session->internals.recv_fd, SOL_TLS, TLS_RX,
-                               &crypto_info, sizeof(crypto_info))) {
-                               return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+                       break;
+                       case GNUTLS_CIPHER_AES_256_GCM:
+                       {
+                               struct tls12_crypto_info_aes_gcm_256 crypto_info;
+                               memset(&crypto_info, 0, sizeof(crypto_info));
+
+                               crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_256;
+                               assert (cipher_key.size == TLS_CIPHER_AES_GCM_256_KEY_SIZE);
+
+                               /* for TLS 1.2 IV is generated in kernel */
+                               if (version == GNUTLS_TLS1_2) {
+                                       crypto_info.info.version = TLS_1_2_VERSION;
+                                       memcpy(crypto_info.iv, seq_number, TLS_CIPHER_AES_GCM_256_IV_SIZE);
+                               } else {
+                                       crypto_info.info.version = TLS_1_3_VERSION;
+                                       assert (iv.size == TLS_CIPHER_AES_GCM_256_SALT_SIZE
+                                                       + TLS_CIPHER_AES_GCM_256_IV_SIZE);
+
+                                       memcpy(crypto_info.iv, iv.data + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+                                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
+                               }
+
+                               memcpy (crypto_info.salt, iv.data,
+                               TLS_CIPHER_AES_GCM_256_SALT_SIZE);
+                               memcpy (crypto_info.rec_seq, seq_number,
+                               TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
+                               memcpy (crypto_info.key, cipher_key.data,
+                               TLS_CIPHER_AES_GCM_256_KEY_SIZE);
+
+                               if (setsockopt (sockin, SOL_TLS, TLS_RX,
+                                               &crypto_info, sizeof (crypto_info))) {
+                                       session->internals.ktls_enabled ^= KTLS_RECV;
+                                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+                               }
                        }
+                       break;
+                       default:
+                               assert(0);
                }
-               break;
-               default:
-                       assert(0);
        }
 
-       ret = gnutls_record_get_state(session, 0, &mac_key, &iv, &cipher_key,
+       ret = gnutls_record_get_state (session, 0, &mac_key, &iv, &cipher_key,
                                                                   seq_number);
        if (ret < 0) {
                return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
        }
 
-       switch (cipher) {
-               case GNUTLS_CIPHER_AES_128_GCM:
-               {
-                       struct tls12_crypto_info_aes_gcm_128 crypto_info;
+       if(session->internals.ktls_enabled & KTLS_SEND){
+               switch (cipher) {
+                       case GNUTLS_CIPHER_AES_128_GCM:
+                       {
+                               struct tls12_crypto_info_aes_gcm_128 crypto_info;
+                               memset(&crypto_info, 0, sizeof(crypto_info));
 
-                       crypto_info.info.version = version;
-                       crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_128;
+                               crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_128;
 
-                       assert(cipher_key.size == TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+                               assert (cipher_key.size == TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 
-                       /* for TLS 1.2 IV is generated in kernel */
-                       if (version == TLS_1_2_VERSION) {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-                       } else {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_128_SALT_SIZE
-                                               + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+                               /* for TLS 1.2 IV is generated in kernel */
+                               if (version == GNUTLS_TLS1_2) {
+                                       crypto_info.info.version = TLS_1_2_VERSION;
+                                       memcpy(crypto_info.iv, seq_number, TLS_CIPHER_AES_GCM_128_IV_SIZE);
+                               } else {
+                                       crypto_info.info.version = TLS_1_3_VERSION;
+                                       assert (iv.size == TLS_CIPHER_AES_GCM_128_SALT_SIZE
+                                                       + TLS_CIPHER_AES_GCM_128_IV_SIZE);
 
-                               memcpy(crypto_info.iv, iv.data + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-                               TLS_CIPHER_AES_GCM_128_IV_SIZE);
-                       }
-
-                       memcpy(crypto_info.salt, iv.data,
-                       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-                       memcpy(crypto_info.rec_seq, seq_number,
-                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
-                       memcpy(crypto_info.key, cipher_key.data,
-                       TLS_CIPHER_AES_GCM_128_KEY_SIZE);
-
-                       if (setsockopt(session->internals.send_fd, SOL_TLS, TLS_TX,
-                               &crypto_info, sizeof(crypto_info))) {
-                               return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
-                       }
-               }
-               break;
-               case GNUTLS_CIPHER_AES_256_GCM:
-               {
-                       struct tls12_crypto_info_aes_gcm_256 crypto_info;
-
-                       crypto_info.info.version = version;
-                       crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_256;
-                       assert(cipher_key.size == TLS_CIPHER_AES_GCM_256_KEY_SIZE);
-
-                       /* for TLS 1.2 IV is generated in kernel */
-                       if (version == TLS_1_2_VERSION) {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_256_SALT_SIZE);
-                       } else {
-                               assert(iv.size == TLS_CIPHER_AES_GCM_256_SALT_SIZE +
-                                               TLS_CIPHER_AES_GCM_256_IV_SIZE);
-
-                               memcpy(crypto_info.iv, iv.data + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
-                               TLS_CIPHER_AES_GCM_256_IV_SIZE);
+                                       memcpy (crypto_info.iv, iv.data + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+                                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
+                               }
+
+                               memcpy (crypto_info.salt, iv.data,
+                               TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+                               memcpy (crypto_info.rec_seq, seq_number,
+                               TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
+                               memcpy (crypto_info.key, cipher_key.data,
+                               TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+
+                               if (setsockopt (sockout, SOL_TLS, TLS_TX,
+                                               &crypto_info, sizeof (crypto_info))) {
+                                       session->internals.ktls_enabled ^= KTLS_SEND;
+                                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+                               }
                        }
-
-                       memcpy(crypto_info.salt, iv.data,
-                       TLS_CIPHER_AES_GCM_256_SALT_SIZE);
-                       memcpy(crypto_info.rec_seq, seq_number,
-                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
-                       memcpy(crypto_info.key, cipher_key.data,
-                       TLS_CIPHER_AES_GCM_256_KEY_SIZE);
-
-                       if (setsockopt(session->internals.send_fd, SOL_TLS, TLS_TX,
-                               &crypto_info, sizeof(crypto_info))) {
-                               return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+                       break;
+                       case GNUTLS_CIPHER_AES_256_GCM:
+                       {
+                               struct tls12_crypto_info_aes_gcm_256 crypto_info;
+                               memset(&crypto_info, 0, sizeof(crypto_info));
+
+                               crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_256;
+                               assert (cipher_key.size == TLS_CIPHER_AES_GCM_256_KEY_SIZE);
+
+                               /* for TLS 1.2 IV is generated in kernel */
+                               if (version == GNUTLS_TLS1_2) {
+                                       crypto_info.info.version = TLS_1_2_VERSION;
+                                       memcpy(crypto_info.iv, seq_number, TLS_CIPHER_AES_GCM_256_IV_SIZE);
+                               } else {
+                                       crypto_info.info.version = TLS_1_3_VERSION;
+                                       assert (iv.size == TLS_CIPHER_AES_GCM_256_SALT_SIZE +
+                                                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
+
+                                       memcpy (crypto_info.iv, iv.data + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+                                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
+                               }
+
+                               memcpy (crypto_info.salt, iv.data,
+                               TLS_CIPHER_AES_GCM_256_SALT_SIZE);
+                               memcpy (crypto_info.rec_seq, seq_number,
+                               TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
+                               memcpy (crypto_info.key, cipher_key.data,
+                               TLS_CIPHER_AES_GCM_256_KEY_SIZE);
+
+                               if (setsockopt (sockout, SOL_TLS, TLS_TX,
+                                               &crypto_info, sizeof (crypto_info))) {
+                                       session->internals.ktls_enabled ^= KTLS_SEND;
+                                       return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
+                               }
                        }
+                       break;
+                       default:
+                               assert(0);
                }
-               break;
-               default:
-                       assert(0);
-
        }
 
-       session->internals.ktls_enabled = 1;
        return 0;
 }
 
@@ -251,8 +265,11 @@ int _gnutls_ktls_send_control_msg(gnutls_session_t session,
 {
        const char *buf = data;
        ssize_t ret;
+       int sockin, sockout;
 
-       assert(session != NULL);
+       assert (session != NULL);
+
+       gnutls_transport_get_int2(session, &sockin, &sockout);
 
        while (data_size > 0) {
                char cmsg[CMSG_SPACE(sizeof (unsigned char))];
@@ -278,7 +295,7 @@ int _gnutls_ktls_send_control_msg(gnutls_session_t session,
                msg.msg_iov = &msg_iov;
                msg.msg_iovlen = 1;
 
-               ret = sendmsg(session->internals.send_fd, &msg, MSG_DONTWAIT);
+               ret = sendmsg(sockout, &msg, MSG_DONTWAIT);
 
                if (ret == -1) {
                        switch (errno) {
@@ -299,17 +316,20 @@ int _gnutls_ktls_send_control_msg(gnutls_session_t session,
 }
 
 int _gnutls_ktls_recv_control_msg(gnutls_session_t session,
-               unsigned char *record_type, void *data, size_t data_size)
+                       unsigned char *record_type, void *data, size_t data_size)
 {
        char *buf = data;
        ssize_t ret;
+       int sockin, sockout;
 
        char cmsg[CMSG_SPACE(sizeof (unsigned char))];
        struct msghdr msg = { 0 };
        struct iovec msg_iov;
        struct cmsghdr *hdr;
 
-       assert(session != NULL);
+       assert (session != NULL);
+
+       gnutls_transport_get_int2(session, &sockin, &sockout);
 
        if (session->internals.read_eof != 0) {
                return 0;
@@ -327,7 +347,7 @@ int _gnutls_ktls_recv_control_msg(gnutls_session_t session,
        msg.msg_iov = &msg_iov;
        msg.msg_iovlen = 1;
 
-       ret = recvmsg(session->internals.recv_fd, &msg, MSG_DONTWAIT);
+       ret = recvmsg(sockin, &msg, MSG_DONTWAIT);
 
        if (ret == -1){
                switch(errno){
@@ -399,8 +419,11 @@ int _gnutls_ktls_recv_int(gnutls_session_t session, content_type_t type,
 }
 
 #else //ENABLE_KTLS
+int gnutls_transport_is_ktls_enabled(gnutls_session_t session){
+       return gnutls_assert_val(GNUTLS_E_UNIMPLEMENTED_FEATURE);
+}
 
-int _gnutls_ktls_enable(gnutls_session_t session, int sockin, int sockout){
+int _gnutls_ktls_enable(gnutls_session_t session){
        return gnutls_assert_val(GNUTLS_E_UNIMPLEMENTED_FEATURE);
 }
 
index 3955052f58df9fcaa10b47fa21630e3cebdab5e9..829799e21244e0ad40b6380dc28163a2590bdb83 100644 (file)
@@ -3,7 +3,13 @@
 
 #include "gnutls_int.h"
 
-int _gnutls_ktls_enable(gnutls_session_t session, int sockin, int sockout);
+enum{
+       KTLS_RECV = 1,
+       KTLS_SEND,
+       KTLS_DUPLEX,
+};
+
+int _gnutls_ktls_enable(gnutls_session_t session);
 int _gnutls_ktls_set_keys(gnutls_session_t session);
 int _gnutls_ktls_send_control_msg(gnutls_session_t session, unsigned char record_type,
                const void *data, size_t data_size);
index 9482e22b3162e64b8c1006d6fa8d372263e6514d..364f010d04fd2f9fefca45a1227553f08629d9ce 100644 (file)
@@ -43,7 +43,7 @@ static void client_log_func(int level, const char *str)
 }
 
 #define MAX_BUF 1024
-#define MSG "Hello world!"
+#define MSG "Hello world!\0"
 
 
 static void client(int fd, const char *prio)
@@ -63,12 +63,13 @@ static void client(int fd, const char *prio)
        gnutls_certificate_allocate_credentials(&x509_cred);
 
        gnutls_init(&session, GNUTLS_CLIENT);
-       gnutls_handshake_set_timeout(session, get_timeout());
+       gnutls_handshake_set_timeout(session, 0);
+
        assert(gnutls_priority_set_direct(session, prio, NULL) >= 0);
+
        gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred);
+
        gnutls_transport_set_int(session, fd);
-       if (ret < 0)
-               fail("client: error in enabling KTLS: %s\n", gnutls_strerror(ret));
 
        do {
                ret = gnutls_handshake(session);
@@ -77,28 +78,22 @@ static void client(int fd, const char *prio)
 
        if (ret < 0) {
                fail("client: Handshake failed\n");
-               close(fd);
-               gnutls_deinit(session);
-               exit(1);
+               goto end;
        }
        if (debug)
                success("client: Handshake was completed\n");
 
+       ret = gnutls_transport_is_ktls_enabled(session);
+       if (ret != 3){
+               fail("client: KTLS was not properly inicialized\n");
+               goto end;
+       }
+
        /* server send message via gnutls_record_send */
-       int i = 0;
        do{
-               memset(buffer, 0, MAX_BUF + 1);
-               do{
-                       ret = gnutls_record_recv(session, buffer, sizeof(buffer));
-               }
-               while(ret == GNUTLS_E_AGAIN);
-
-               if(strncmp(buffer, MSG+i*MAX_BUF, MAX_BUF))
-                       fail("client: Message doesn't match\n");
-       } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
-
-       if (debug)
-               success ("client: messages received\n");
+               ret = gnutls_record_recv(session, buffer, sizeof(buffer));
+       }
+       while(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
 
        if (ret == 0) {
                        success
@@ -106,13 +101,24 @@ static void client(int fd, const char *prio)
                goto end;
        } else if (ret < 0) {
                fail("client: Error: %s\n", gnutls_strerror(ret));
-               exit(1);
+               goto end;
+       }
+
+       if(strncmp(buffer, MSG, ret)){
+               fail("client: Message doesn't match\n");
+               goto end;
        }
 
+       if (debug)
+               success ("client: messages received\n");
+
+
        ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
        if (ret < 0) {
                fail("client: error in closing session: %s\n", gnutls_strerror(ret));
        }
+
+       ret = 0;
  end:
 
        close(fd);
@@ -122,11 +128,15 @@ static void client(int fd, const char *prio)
        gnutls_certificate_free_credentials(x509_cred);
 
        gnutls_global_deinit();
+
+       if (ret != 0)
+               exit(1);
 }
 
 pid_t child;
 static void terminate(void)
 {
+       assert(child);
        kill(child, SIGTERM);
        exit(1);
 }
@@ -152,7 +162,7 @@ static void server(int fd, const char *prio)
                exit(1);
 
        gnutls_init(&session, GNUTLS_SERVER);
-       gnutls_handshake_set_timeout(session, get_timeout());
+       gnutls_handshake_set_timeout(session, 0);
 
        assert(gnutls_priority_set_direct(session, prio, NULL)>=0);
 
@@ -166,33 +176,34 @@ static void server(int fd, const char *prio)
        while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
 
        if (ret < 0) {
-               close(fd);
-               gnutls_deinit(session);
                fail("server: Handshake has failed (%s)\n\n",
                     gnutls_strerror(ret));
-               terminate();
+               goto end;
        }
        if (debug)
                success("server: Handshake was completed\n");
 
+       ret = gnutls_transport_is_ktls_enabled(session);
+       if (ret != 3){
+               fail("server: KTLS was not properly inicialized\n");
+               goto end;
+       }
        do {
-               ret = gnutls_record_send(session, MSG, strlen(MSG));
+               ret = gnutls_record_send(session, MSG, strlen(MSG)+1);
        } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
 
        if (ret < 0) {
-               close(fd);
-               gnutls_deinit(session);
-               gnutls_certificate_free_credentials(x509_cred);
-               gnutls_global_deinit();
                fail("server: data sending has failed (%s)\n\n",
                     gnutls_strerror(ret));
-               terminate();
+                        goto end;
        }
 
        ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
        if (ret < 0) {
                fail("server: error in closing session: %s\n", gnutls_strerror(ret));
 
+       ret = 0;
+end:
        close(fd);
        gnutls_deinit(session);
 
@@ -200,11 +211,20 @@ static void server(int fd, const char *prio)
 
        gnutls_global_deinit();
 
+       if (ret){
+               terminate();
+       }
+
        if (debug)
                success("server: finished\n");
        }
 }
 
+static void ch_handler(int sig)
+{
+       return;
+}
+
 static void run(const char *prio)
 {
        int ret;
@@ -215,7 +235,7 @@ static void run(const char *prio)
 
        success("running ktls test with %s\n", prio);
 
-       signal(SIGCHLD, SIG_IGN);
+       signal(SIGCHLD, ch_handler);
        signal(SIGPIPE, SIG_IGN);
 
        listener = socket(AF_INET, SOCK_STREAM, 0);
@@ -246,6 +266,7 @@ static void run(const char *prio)
        }
 
        if (child) {
+               int status;
                /* parent */
                ret = listen(listener, 1);
                if (ret == -1) {
@@ -257,7 +278,9 @@ static void run(const char *prio)
                        fail("error in accept(): %s\n", strerror(errno));
                }
                server(fd, prio);
-               kill(child, SIGTERM);
+
+               wait(&status);
+               check_wait_status(status);
        } else {
                fd = socket(AF_INET, SOCK_STREAM, 0);
                if (fd == -1){
@@ -273,10 +296,10 @@ static void run(const char *prio)
 
 void doit(void)
 {
-       run("NORMAL:-VERS-ALL:+VERS-TLS1.2:+AES-128-GCM");
-       run("NORMAL:-VERS-ALL:+VERS-TLS1.2:+AES-256-GCM");
-       run("NORMAL:-VERS-ALL:+VERS-TLS1.3:+AES-128-GCM");
-       run("NORMAL:-VERS-ALL:+VERS-TLS1.3:+AES-256-GCM");
+       run("NORMAL:-VERS-ALL:+VERS-TLS1.2:-CIPHER-ALL:+AES-128-GCM");
+       run("NORMAL:-VERS-ALL:+VERS-TLS1.2:-CIPHER-ALL:+AES-256-GCM");
+       run("NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-128-GCM");
+       run("NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-256-GCM");
 }
 
 #endif                         /* _WIN32 */