]> git.ipfire.org Git - thirdparty/gnutls.git/commitdiff
kTLS: keyupdate_test improvements
authorFrantisek Krenzelok <krenzelok.frantisek@gmail.com>
Tue, 11 Feb 2025 11:45:44 +0000 (12:45 +0100)
committerKrenzelok Frantisek <krenzelok.frantisek@gmail.com>
Fri, 28 Mar 2025 10:18:29 +0000 (11:18 +0100)
- ktls_utils.h has helper funtion to create standard sockets required
  for ktls support testing.
- key_update test for kTLS is now a flavourt of the tls13/key_update
  test instead of being standalone(broadens the testing cases).
- gnutls_ktls.c now uses the aformentioned ktls_utils.h

Signed-off-by: Frantisek Krenzelok <krenzelok.frantisek@gmail.com>
tests/Makefile.am
tests/gnutls_ktls.c
tests/ktls_keyupdate.c [deleted file]
tests/ktls_utils.h [new file with mode: 0644]
tests/tls13/key_update.c

index 72926e9da45552a458d35032b2b7380c19504676..9990ee21cc60dc06ebc5647185171a6413894a0b 100644 (file)
@@ -517,12 +517,6 @@ if ENABLE_TPM2
 dist_check_SCRIPTS += tpm2.sh
 endif
 
-if ENABLE_KTLS
-indirect_tests += gnutls_ktls
-dist_check_SCRIPTS += ktls.sh
-indirect_tests += ktls_keyupdate
-dist_check_SCRIPTS += ktls_keyupdate.sh
-endif
 
 if !WINDOWS
 
@@ -530,6 +524,16 @@ if !WINDOWS
 # List of tests not available/functional under windows
 #
 
+if ENABLE_KTLS
+indirect_tests += gnutls_ktls
+dist_check_SCRIPTS += ktls.sh
+
+indirect_tests += ktls_keyupdate
+ktls_keyupdate_SOURCES = tls13/key_update.c
+ktls_keyupdate_CFLAGS = -DUSE_KTLS
+dist_check_SCRIPTS += ktls_keyupdate.sh
+endif
+
 dist_check_SCRIPTS += dtls/dtls.sh dtls/dtls-resume.sh #dtls/dtls-nb
 
 indirect_tests += dtls-stress
index 90d3e9af91a0d131a9caa711c4ce7ca0bdf47884..ca576d42aacc12ae5337c92bdff58de8e1a6a886 100644 (file)
@@ -31,6 +31,7 @@ int main(void)
 
 #include "cert-common.h"
 #include "utils.h"
+#include "ktls_utils.h"
 
 static void server_log_func(int level, const char *str)
 {
@@ -94,7 +95,8 @@ static void client(int fd, const char *prio)
        } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
 
        if (ret == 0) {
-               success("client: Peer has closed the TLS connection\n");
+               if (debug)
+                       success("client: Peer has closed the TLS connection\n");
                goto end;
        } else if (ret < 0) {
                fail("client: Error: %s\n", gnutls_strerror(ret));
@@ -116,7 +118,8 @@ static void client(int fd, const char *prio)
        } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
 
        if (ret == 0) {
-               success("client: Peer has closed the TLS connection\n");
+               if (debug)
+                       success("client: Peer has closed the TLS connection\n");
                goto end;
        } else if (ret < 0) {
                fail("client: Error: %s\n", gnutls_strerror(ret));
@@ -277,35 +280,16 @@ static void ch_handler(int sig)
 static void run(const char *prio)
 {
        int ret;
-       struct sockaddr_in saddr;
-       socklen_t addrlen;
-       int listener;
-       int fd;
+       int client_fd, server_fd;
 
        success("running ktls test with %s\n", prio);
 
        signal(SIGCHLD, ch_handler);
        signal(SIGPIPE, SIG_IGN);
 
-       listener = socket(AF_INET, SOCK_STREAM, 0);
-       if (listener == -1) {
-               fail("error in listener(): %s\n", strerror(errno));
-       }
-
-       memset(&saddr, 0, sizeof(saddr));
-       saddr.sin_family = AF_INET;
-       saddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
-       saddr.sin_port = 0;
-
-       ret = bind(listener, (struct sockaddr *)&saddr, sizeof(saddr));
-       if (ret == -1) {
-               fail("error in bind(): %s\n", strerror(errno));
-       }
-
-       addrlen = sizeof(saddr);
-       ret = getsockname(listener, (struct sockaddr *)&saddr, &addrlen);
-       if (ret == -1) {
-               fail("error in getsockname(): %s\n", strerror(errno));
+       if ((ret = create_socket_pair(&client_fd, &server_fd))) {
+               fail("Error in socket creation: %d\n", ret);
+               exit(1);
        }
 
        child = fork();
@@ -315,30 +299,20 @@ static void run(const char *prio)
        }
 
        if (child) {
-               int status;
                /* parent */
-               ret = listen(listener, 1);
-               if (ret == -1) {
-                       fail("error in listen(): %s\n", strerror(errno));
-               }
+               int status;
+               close(client_fd);
 
-               fd = accept(listener, NULL, NULL);
-               if (fd == -1) {
-                       fail("error in accept(): %s\n", strerror(errno));
-               }
-               server(fd, prio);
+               server(server_fd, prio);
 
                wait(&status);
                check_wait_status(status);
        } else {
-               fd = socket(AF_INET, SOCK_STREAM, 0);
-               if (fd == -1) {
-                       fail("error in socket(): %s\n", strerror(errno));
-                       exit(1);
-               }
-               usleep(1000000);
-               connect(fd, (struct sockaddr *)&saddr, addrlen);
-               client(fd, prio);
+               close(server_fd);
+               sleep(1);
+
+               client(client_fd, prio);
+
                exit(0);
        }
 }
diff --git a/tests/ktls_keyupdate.c b/tests/ktls_keyupdate.c
deleted file mode 100644 (file)
index b439b7d..0000000
+++ /dev/null
@@ -1,368 +0,0 @@
-// Copyright (C) 2022 Red Hat, Inc.
-//
-// Author: Frantisek Krenzelok
-//
-// This file is part of GnuTLS.
-//
-// GnuTLS is free software; you can redistribute it and/or modify it
-// under the terms of the GNU General Public License as published by the
-// Free Software Foundation; either version 3 of the License, or (at
-// your option) any later version.
-//
-// GnuTLS is distributed in the hope that it will be useful, but
-// WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
-// General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with GnuTLS.  If not, see <https://www.gnu.org/licenses/>.
-
-#ifdef HAVE_CONFIG_H
-#include "config.h"
-#endif
-
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
-#include <netinet/in.h>
-#include <sys/socket.h>
-#include <sys/wait.h>
-#include <arpa/inet.h>
-#include <unistd.h>
-#include <gnutls/gnutls.h>
-#include <gnutls/crypto.h>
-#include <gnutls/dtls.h>
-#include <gnutls/socket.h>
-#include <signal.h>
-#include <assert.h>
-#include <errno.h>
-
-#include "cert-common.h"
-#include "utils.h"
-
-#if defined(_WIN32)
-
-int main(void)
-{
-       exit(77);
-}
-
-#else
-
-#define MAX_BUF 1024
-#define MSG "Hello world!"
-
-#define HANDSHAKE(session, name, ret)                                 \
-       {                                                             \
-               do {                                                  \
-                       ret = gnutls_handshake(session);              \
-               } while (ret < 0 && gnutls_error_is_fatal(ret) == 0); \
-               if (ret < 0) {                                        \
-                       fail("%s: Handshake failed\n", name);         \
-                       goto end;                                     \
-               }                                                     \
-       }
-
-#define SEND_MSG(session, name, ret)                                     \
-       {                                                                \
-               do {                                                     \
-                       ret = gnutls_record_send(session, MSG,           \
-                                                strlen(MSG) + 1);       \
-               } while (ret == GNUTLS_E_AGAIN ||                        \
-                        ret == GNUTLS_E_INTERRUPTED);                   \
-               if (ret < 0) {                                           \
-                       fail("%s: data sending has failed (%s)\n", name, \
-                            gnutls_strerror(ret));                      \
-                       goto end;                                        \
-               }                                                        \
-       }
-
-#define RECV_MSG(session, name, buffer, buffer_len, ret)                       \
-       {                                                                      \
-               memset(buffer, 0, sizeof(buffer));                             \
-               do {                                                           \
-                       ret = gnutls_record_recv(session, buffer,              \
-                                                sizeof(buffer));              \
-               } while (ret == GNUTLS_E_AGAIN ||                              \
-                        ret == GNUTLS_E_INTERRUPTED);                         \
-               if (ret == 0) {                                                \
-                       success("%s: Peer has closed the TLS connection\n",    \
-                               name);                                         \
-                       goto end;                                              \
-               } else if (ret < 0) {                                          \
-                       fail("%s: Error -> %s\n", name, gnutls_strerror(ret)); \
-                       goto end;                                              \
-               }                                                              \
-               if (strncmp(buffer, MSG, ret)) {                               \
-                       fail("%s: Message doesn't match\n", name);             \
-                       goto end;                                              \
-               }                                                              \
-       }
-
-#define KEY_UPDATE(session, name, peer_req, ret)                            \
-       {                                                                   \
-               do {                                                        \
-                       ret = gnutls_session_key_update(session, peer_req); \
-               } while (ret == GNUTLS_E_AGAIN ||                           \
-                        ret == GNUTLS_E_INTERRUPTED);                      \
-               if (ret < 0) {                                              \
-                       fail("%s: key update has failed (%s)\n", name,      \
-                            gnutls_strerror(ret));                         \
-                       goto end;                                           \
-               }                                                           \
-       }
-
-#define CHECK_KTLS_ENABLED(session, ret)                                     \
-       {                                                                    \
-               ret = gnutls_transport_is_ktls_enabled(session);             \
-               if (!(ret & GNUTLS_KTLS_RECV)) {                             \
-                       fail("client: KTLS was not properly initialized\n"); \
-                       goto end;                                            \
-               }                                                            \
-       }
-
-static void server_log_func(int level, const char *str)
-{
-       fprintf(stderr, "server|<%d>| %s", level, str);
-}
-
-static void client_log_func(int level, const char *str)
-{
-       fprintf(stderr, "client|<%d>| %s", level, str);
-}
-
-static void client(int fd, const char *prio, int pipe)
-{
-       const char *name = "client";
-       int ret;
-       char foo;
-       char buffer[MAX_BUF + 1];
-       gnutls_certificate_credentials_t x509_cred;
-       gnutls_session_t session;
-
-       global_init();
-
-       if (debug) {
-               gnutls_global_set_log_function(client_log_func);
-               gnutls_global_set_log_level(7);
-       }
-
-       gnutls_certificate_allocate_credentials(&x509_cred);
-
-       gnutls_init(&session, GNUTLS_CLIENT);
-       gnutls_handshake_set_timeout(session, 0);
-
-       assert(gnutls_priority_set_direct(session, prio, NULL) >= 0);
-
-       gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred);
-
-       gnutls_transport_set_int(session, fd);
-
-       HANDSHAKE(session, name, ret);
-
-       CHECK_KTLS_ENABLED(session, ret)
-       // Test 0: Try sending/receiving data
-       RECV_MSG(session, name, buffer, MAX_BUF + 1, ret)
-       SEND_MSG(session, name, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       // Test 1: Servers does key update
-       read(pipe, &foo, 1);
-       RECV_MSG(session, name, buffer, MAX_BUF + 1, ret)
-       SEND_MSG(session, name, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       // Test 2: Does key update witch request
-       read(pipe, &foo, 1);
-       RECV_MSG(session, name, buffer, MAX_BUF + 1, ret)
-       SEND_MSG(session, name, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
-       if (ret < 0) {
-               fail("client: error in closing session: %s\n",
-                    gnutls_strerror(ret));
-       }
-
-       ret = 0;
-end:
-
-       close(fd);
-
-       gnutls_deinit(session);
-
-       gnutls_certificate_free_credentials(x509_cred);
-
-       gnutls_global_deinit();
-
-       if (ret != 0)
-               exit(1);
-}
-
-pid_t child;
-static void terminate(void)
-{
-       assert(child);
-       kill(child, SIGTERM);
-       exit(1);
-}
-
-static void server(int fd, const char *prio, int pipe)
-{
-       const char *name = "server";
-       int ret;
-       char bar = 0;
-       char buffer[MAX_BUF + 1];
-       gnutls_certificate_credentials_t x509_cred;
-       gnutls_session_t session;
-
-       global_init();
-
-       if (debug) {
-               gnutls_global_set_log_function(server_log_func);
-               gnutls_global_set_log_level(7);
-       }
-
-       gnutls_certificate_allocate_credentials(&x509_cred);
-       ret = gnutls_certificate_set_x509_key_mem(
-               x509_cred, &server_cert, &server_key, GNUTLS_X509_FMT_PEM);
-       if (ret < 0)
-               exit(1);
-
-       gnutls_init(&session, GNUTLS_SERVER);
-       gnutls_handshake_set_timeout(session, 0);
-
-       assert(gnutls_priority_set_direct(session, prio, NULL) >= 0);
-
-       gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred);
-
-       gnutls_transport_set_int(session, fd);
-
-       HANDSHAKE(session, name, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       success("Test 0: sending/receiving data\n");
-       SEND_MSG(session, name, ret)
-       RECV_MSG(session, name, buffer, MAX_BUF + 1, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       success("Test 1: server key update without request\n");
-       KEY_UPDATE(session, name, 0, ret)
-       write(pipe, &bar, 1);
-       SEND_MSG(session, name, ret)
-       RECV_MSG(session, name, buffer, MAX_BUF + 1, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       success("Test 2: server key update with request\n");
-       KEY_UPDATE(session, name, GNUTLS_KU_PEER, ret)
-       write(pipe, &bar, 1);
-       SEND_MSG(session, name, ret)
-       RECV_MSG(session, name, buffer, MAX_BUF + 1, ret)
-       CHECK_KTLS_ENABLED(session, ret)
-       ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
-       if (ret < 0) {
-               fail("server: error in closing session: %s\n",
-                    gnutls_strerror(ret));
-       }
-
-       ret = 0;
-end:
-       close(fd);
-       gnutls_deinit(session);
-
-       gnutls_certificate_free_credentials(x509_cred);
-
-       gnutls_global_deinit();
-
-       if (ret) {
-               terminate();
-       }
-
-       if (debug)
-               success("server: finished\n");
-}
-
-static void ch_handler(int sig)
-{
-       return;
-}
-
-static void run(const char *prio)
-{
-       int ret;
-       struct sockaddr_in saddr;
-       socklen_t addrlen;
-       int listener;
-       int fd;
-
-       int sync_pipe[2]; //used for synchronization
-       pipe(sync_pipe);
-
-       success("running ktls test with %s\n", prio);
-
-       signal(SIGCHLD, ch_handler);
-       signal(SIGPIPE, SIG_IGN);
-
-       listener = socket(AF_INET, SOCK_STREAM, 0);
-       if (listener == -1) {
-               fail("error in listener(): %s\n", strerror(errno));
-       }
-
-       memset(&saddr, 0, sizeof(saddr));
-       saddr.sin_family = AF_INET;
-       saddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
-       saddr.sin_port = 0;
-
-       ret = bind(listener, (struct sockaddr *)&saddr, sizeof(saddr));
-       if (ret == -1) {
-               fail("error in bind(): %s\n", strerror(errno));
-       }
-
-       addrlen = sizeof(saddr);
-       ret = getsockname(listener, (struct sockaddr *)&saddr, &addrlen);
-       if (ret == -1) {
-               fail("error in getsockname(): %s\n", strerror(errno));
-       }
-
-       child = fork();
-       if (child < 0) {
-               fail("error in fork(): %s\n", strerror(errno));
-               exit(1);
-       }
-
-       if (child) {
-               int status;
-               /* parent */
-               ret = listen(listener, 1);
-               if (ret == -1) {
-                       fail("error in listen(): %s\n", strerror(errno));
-               }
-
-               fd = accept(listener, NULL, NULL);
-               if (fd == -1) {
-                       fail("error in accept(): %s\n", strerror(errno));
-               }
-
-               close(sync_pipe[0]);
-               server(fd, prio, sync_pipe[1]);
-
-               wait(&status);
-               check_wait_status(status);
-       } else {
-               fd = socket(AF_INET, SOCK_STREAM, 0);
-               if (fd == -1) {
-                       fail("error in socket(): %s\n", strerror(errno));
-                       exit(1);
-               }
-
-               usleep(1000000);
-               connect(fd, (struct sockaddr *)&saddr, addrlen);
-
-               close(sync_pipe[1]);
-               client(fd, prio, sync_pipe[0]);
-               exit(0);
-       }
-}
-
-void doit(void)
-{
-       run("NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-128-GCM");
-       run("NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-256-GCM");
-}
-
-#endif /* _WIN32 */
diff --git a/tests/ktls_utils.h b/tests/ktls_utils.h
new file mode 100644 (file)
index 0000000..231618d
--- /dev/null
@@ -0,0 +1,94 @@
+#ifndef GNUTLS_TESTS_KTLS_UTILS_H
+#define GNUTLS_TESTS_KTLS_UTILS_H
+
+#include <fcntl.h>
+#include <signal.h>
+
+#include <netinet/in.h>
+
+#include <sys/socket.h>
+#include <sys/wait.h>
+
+/* Sets the NONBLOCK flag on the socket(fd) */
+inline static int set_nonblocking(int fd)
+{
+       int flags = fcntl(fd, F_GETFL, 0);
+       if (flags == -1) {
+               return 1;
+       }
+
+       if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
+               return 2;
+       }
+
+       return 0;
+}
+
+/* Creates a pair of TCP connected sockets */
+static int create_socket_pair(int *client_fd, int *server_fd)
+{
+       int ret;
+       struct sockaddr_in saddr;
+       socklen_t addrlen;
+       int listener;
+
+       listener = socket(AF_INET, SOCK_STREAM, 0);
+       if (listener == -1) {
+               fail("error in listener(): %s\n", strerror(errno));
+               return 1;
+       }
+
+       int opt = 0;
+       setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
+
+       memset(&saddr, 0, sizeof(saddr));
+       saddr.sin_family = AF_INET;
+       saddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+       saddr.sin_port = 0;
+
+       ret = bind(listener, (struct sockaddr *)&saddr, sizeof(saddr));
+       if (ret == -1) {
+               fail("error in bind(): %s\n", strerror(errno));
+               return 1;
+       }
+
+       addrlen = sizeof(saddr);
+       ret = getsockname(listener, (struct sockaddr *)&saddr, &addrlen);
+       if (ret == -1) {
+               fail("error in getsockname(): %s\n", strerror(errno));
+               return 1;
+       }
+
+       ret = listen(listener, 1);
+       if (ret == -1) {
+               fail("error in listen(): %s\n", strerror(errno));
+               close(listener);
+               return 1;
+       }
+
+       *client_fd = socket(AF_INET, SOCK_STREAM, 0);
+       if (*client_fd < 0) {
+               fail("error in socket(): %s\n", strerror(errno));
+               return 1;
+       }
+
+       ret = connect(*client_fd, (struct sockaddr *)&saddr, addrlen);
+       if (ret < 0) {
+               fail("error in connect(): %s\n", strerror(errno));
+               close(listener);
+               close(*client_fd);
+               return 1;
+       }
+
+       *server_fd = accept(listener, NULL, NULL);
+       if (*server_fd < 0) {
+               fail("error in accept(): %s\n", strerror(errno));
+               close(listener);
+               close(*client_fd);
+               return 1;
+       }
+
+       return 0;
+}
+
+#endif //GNUTLS_TESTS_KTLS_UTILS_H
index 19bea7b9ac14811825e50dde28b150ba94c3c0b1..9c72ee14ab691266c4486805de0195095ffc04ab 100644 (file)
@@ -47,6 +47,67 @@ static void tls_log_func(int level, const char *str)
 #define MSG \
        "Hello TLS, and hi and how are you and more data here... and more... and even more and even more more data..."
 
+#ifdef USE_KTLS
+/* For ktls flavour of the test, the test needs to be run with the kTLS enabled
+ * option in the GnuTLS configuration file, in the case of the GnuTLS testing,
+ * the shell script `tests/ktls_keyupdate.sh` does just that.
+ */
+
+#include "ktls_utils.h"
+
+#define RUN(s, n)                                         \
+       {                                                 \
+               int ret;                                  \
+               int cfd, sfd;                             \
+               if ((ret = create_sockets(&cfd, &sfd))) { \
+                       fail("kTLS: %s\n", errors[ret]);  \
+                       exit(1);                          \
+               }                                         \
+               run(s, n, cfd, sfd);                      \
+               close(cfd);                               \
+               close(sfd);                               \
+       }
+
+/* We could check specificaly but given how the setting of new keys work, let's
+ * check if something had gone sideways on both the receive and sending sockets.
+ */
+#define CHECK_KTLS_ENABLED(session)                             \
+       switch (gnutls_transport_is_ktls_enabled(session)) {    \
+       case GNUTLS_KTLS_RECV:                                  \
+               fail("kTLS: Only recv support is initiated\n"); \
+               break;                                          \
+       case GNUTLS_KTLS_SEND:                                  \
+               fail("kTLS: Only send support is initiated\n"); \
+               break;                                          \
+       case GNUTLS_KTLS_DUPLEX:                                \
+               break;                                          \
+       default:                                                \
+               fail("kTLS: dissabled\n");                      \
+       }
+
+/* ktls needs to use real sockets */
+char *errors[] = { "", "Failed to create the socket pair",
+                  "Failed to set the socket non-blocking" };
+
+inline static int create_sockets(int *cfd, int *sfd)
+{
+       int ret = create_socket_pair(cfd, sfd);
+       if (ret)
+               return 1;
+
+       if (set_nonblocking(*cfd) || set_nonblocking(*sfd))
+               return 2;
+
+       return 0;
+}
+
+#else /* non-kTLS */
+
+#define RUN(s, n) run(s, n, -1, -1)
+#define CHECK_KTLS_ENABLED(session) /* No-op */
+
+#endif /* non-kTLS */
+
 static unsigned key_update_msg_inc = 0;
 static unsigned key_update_msg_out = 0;
 
@@ -68,7 +129,7 @@ static int hsk_callback(gnutls_session_t session, unsigned int htype,
        return 0;
 }
 
-static void run(const char *name, unsigned test)
+static void run(const char *name, unsigned test, int cfd, int sfd)
 {
        /* Server stuff. */
        gnutls_certificate_credentials_t ccred;
@@ -100,9 +161,14 @@ static void run(const char *name, unsigned test)
                exit(1);
 
        gnutls_credentials_set(server, GNUTLS_CRD_CERTIFICATE, scred);
+#ifdef USE_KTLS
+       /* for kTLS you can't use custom push/pull functions */
+       gnutls_transport_set_int(server, sfd);
+#else
        gnutls_transport_set_push_function(server, server_push);
        gnutls_transport_set_pull_function(server, server_pull);
        gnutls_transport_set_ptr(server, server);
+#endif
 
        /* Init client */
        assert(gnutls_certificate_allocate_credentials(&ccred) >= 0);
@@ -118,14 +184,23 @@ static void run(const char *name, unsigned test)
        if (ret < 0)
                exit(1);
 
+#ifdef USE_KTLS
+       /* for kTLS you can't use custom push/pull functions */
+       gnutls_transport_set_int(client, cfd);
+#else
        gnutls_transport_set_push_function(client, client_push);
        gnutls_transport_set_pull_function(client, client_pull);
        gnutls_transport_set_ptr(client, client);
+#endif
 
        HANDSHAKE(client, server);
        if (debug)
                success("Handshake established\n");
 
+       /* check kTLS initialization (both send and recv should be supported)*/
+       CHECK_KTLS_ENABLED(client);
+       CHECK_KTLS_ENABLED(server);
+
        switch (test) {
        case 0:
        case 1:
@@ -134,10 +209,14 @@ static void run(const char *name, unsigned test)
                        ret = gnutls_session_key_update(client, 0);
                } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
 
+               CHECK_KTLS_ENABLED(client);
+               CHECK_KTLS_ENABLED(server);
+
                /* server receives the client key update and sends data */
                TRANSFER(client, server, MSG, strlen(MSG), buffer, MAX_BUF);
                TRANSFER(server, client, MSG, strlen(MSG), buffer, MAX_BUF);
                EMPTY_BUF(server, client, buffer, MAX_BUF);
+
                if (test != 0)
                        break;
                sec_sleep(2);
@@ -151,10 +230,14 @@ static void run(const char *name, unsigned test)
                if (ret < 0)
                        fail("error in key update: %s\n", gnutls_strerror(ret));
 
+               CHECK_KTLS_ENABLED(server);
+               CHECK_KTLS_ENABLED(client);
+
                /* client receives the key update and sends data */
                TRANSFER(client, server, MSG, strlen(MSG), buffer, MAX_BUF);
                TRANSFER(server, client, MSG, strlen(MSG), buffer, MAX_BUF);
                EMPTY_BUF(server, client, buffer, MAX_BUF);
+
                if (test != 0)
                        break;
                sec_sleep(2);
@@ -167,10 +250,14 @@ static void run(const char *name, unsigned test)
                if (ret < 0)
                        fail("error in key update: %s\n", gnutls_strerror(ret));
 
+               CHECK_KTLS_ENABLED(client);
+               CHECK_KTLS_ENABLED(server);
+
                /* server receives the client key update and sends data */
                TRANSFER(client, server, MSG, strlen(MSG), buffer, MAX_BUF);
                TRANSFER(server, client, MSG, strlen(MSG), buffer, MAX_BUF);
                EMPTY_BUF(server, client, buffer, MAX_BUF);
+
                if (test != 0)
                        break;
                sec_sleep(2);
@@ -183,6 +270,9 @@ static void run(const char *name, unsigned test)
                if (ret < 0)
                        fail("error in key update: %s\n", gnutls_strerror(ret));
 
+               CHECK_KTLS_ENABLED(server);
+               CHECK_KTLS_ENABLED(client);
+
                TRANSFER(client, server, MSG, strlen(MSG), buffer, MAX_BUF);
                TRANSFER(server, client, MSG, strlen(MSG), buffer, MAX_BUF);
                EMPTY_BUF(server, client, buffer, MAX_BUF);
@@ -210,6 +300,9 @@ static void run(const char *name, unsigned test)
                /* client receives key update */
                EMPTY_BUF(server, client, buffer, MAX_BUF);
 
+               CHECK_KTLS_ENABLED(server);
+               CHECK_KTLS_ENABLED(client);
+
                /* client uncorks and sends key update */
                do {
                        ret = gnutls_record_uncork(client, GNUTLS_RECORD_WAIT);
@@ -217,6 +310,9 @@ static void run(const char *name, unsigned test)
                if (ret < 0)
                        fail("cannot send: %s\n", gnutls_strerror(ret));
 
+               CHECK_KTLS_ENABLED(server);
+               CHECK_KTLS_ENABLED(client);
+
                EMPTY_BUF(server, client, buffer, MAX_BUF);
 
                sec_sleep(2);
@@ -238,6 +334,9 @@ static void run(const char *name, unsigned test)
                if (ret < 0)
                        fail("error in key update: %s\n", gnutls_strerror(ret));
 
+               CHECK_KTLS_ENABLED(server);
+               CHECK_KTLS_ENABLED(client);
+
                /* server receives the client key update and sends data */
                TRANSFER(client, server, MSG, strlen(MSG), buffer, MAX_BUF);
                TRANSFER(server, client, MSG, strlen(MSG), buffer, MAX_BUF);
@@ -263,11 +362,11 @@ static void run(const char *name, unsigned test)
 
 void doit(void)
 {
-       run("single", 1);
-       run("single", 2);
-       run("single", 3);
-       run("single", 4);
-       run("single", 5);
-       run("single", 6);
-       run("all", 0); /* all one after each other */
+       RUN("single", 1);
+       RUN("single", 2);
+       RUN("single", 3);
+       RUN("single", 4);
+       RUN("single", 5);
+       RUN("single", 6);
+       RUN("all", 0); /* all one after each other */
 }