From: Anthony Brandon Date: Fri, 30 May 2025 14:47:21 +0000 (+0200) Subject: tls: move gnutls code into tls_gnutls.c X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3e32e7e69412042f2477c02319e36eab6aee133d;p=thirdparty%2Fchrony.git tls: move gnutls code into tls_gnutls.c Currently nts_ke_session.c directly calls into gnutls. This patch moves the calls to gnutls into tls_gnutls.c with an API defined in tls.h. This way it becomes possible to use different TLS implementations in future patches. Signed-off-by: Anthony Brandon --- diff --git a/configure b/configure index bcd69449..7f978c61 100755 --- a/configure +++ b/configure @@ -1011,7 +1011,7 @@ if [ $feat_nts = "1" ] && [ $try_gnutls = "1" ]; then fi if grep '#define HAVE_SIV' config.h > /dev/null; then - EXTRA_OBJECTS="$EXTRA_OBJECTS nts_ke_client.o nts_ke_server.o nts_ke_session.o" + EXTRA_OBJECTS="$EXTRA_OBJECTS nts_ke_client.o nts_ke_server.o nts_ke_session.o tls_gnutls.o" EXTRA_OBJECTS="$EXTRA_OBJECTS nts_ntp_auth.o nts_ntp_client.o nts_ntp_server.o" LIBS="$LIBS $test_link" MYCPPFLAGS="$MYCPPFLAGS $test_cflags" diff --git a/nts_ke_session.c b/nts_ke_session.c index c5f49d3d..7e09ac5f 100644 --- a/nts_ke_session.c +++ b/nts_ke_session.c @@ -37,11 +37,9 @@ #include "siv.h" #include "socket.h" #include "sched.h" +#include "tls.h" #include "util.h" -#include -#include - #define INVALID_SOCK_FD (-8) struct RecordHeader { @@ -75,7 +73,7 @@ struct NKSN_Instance_Record { KeState state; int sock_fd; char *label; - gnutls_session_t tls_session; + TLS_Instance tls_session; SCH_TimeoutID timeout_id; int retry_factor; @@ -85,8 +83,6 @@ struct NKSN_Instance_Record { /* ================================================== */ -static gnutls_priority_t priority_cache; - static int credentials_counter = 0; static int clock_updates = 0; @@ -206,71 +202,6 @@ check_message_format(struct Message *message, int eof) /* ================================================== */ -static gnutls_session_t -create_tls_session(int server_mode, int sock_fd, const char *server_name, - gnutls_certificate_credentials_t credentials, - gnutls_priority_t priority) -{ - unsigned char alpn_name[sizeof (NKE_ALPN_NAME)]; - gnutls_session_t session; - gnutls_datum_t alpn; - unsigned int flags; - int r; - - r = gnutls_init(&session, GNUTLS_NONBLOCK | GNUTLS_NO_TICKETS | - (server_mode ? GNUTLS_SERVER : GNUTLS_CLIENT)); - if (r < 0) { - LOG(LOGS_ERR, "Could not %s TLS session : %s", "create", gnutls_strerror(r)); - return NULL; - } - - if (!server_mode) { - assert(server_name); - - if (!UTI_IsStringIP(server_name)) { - r = gnutls_server_name_set(session, GNUTLS_NAME_DNS, server_name, strlen(server_name)); - if (r < 0) - goto error; - } - - flags = 0; - - if (clock_updates < CNF_GetNoCertTimeCheck()) { - flags |= GNUTLS_VERIFY_DISABLE_TIME_CHECKS | GNUTLS_VERIFY_DISABLE_TRUSTED_TIME_CHECKS; - DEBUG_LOG("Disabled time checks"); - } - - gnutls_session_set_verify_cert(session, server_name, flags); - } - - r = gnutls_priority_set(session, priority); - if (r < 0) - goto error; - - r = gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, credentials); - if (r < 0) - goto error; - - memcpy(alpn_name, NKE_ALPN_NAME, sizeof (alpn_name)); - alpn.data = alpn_name; - alpn.size = sizeof (alpn_name) - 1; - - r = gnutls_alpn_set_protocols(session, &alpn, 1, 0); - if (r < 0) - goto error; - - gnutls_transport_set_int(session, sock_fd); - - return session; - -error: - LOG(LOGS_ERR, "Could not %s TLS session : %s", "set", gnutls_strerror(r)); - gnutls_deinit(session); - return NULL; -} - -/* ================================================== */ - static void stop_session(NKSN_Instance inst) { @@ -286,7 +217,7 @@ stop_session(NKSN_Instance inst) Free(inst->label); inst->label = NULL; - gnutls_deinit(inst->tls_session); + TLS_DestroyInstance(inst->tls_session); inst->tls_session = NULL; SCH_RemoveTimeout(inst->timeout_id); @@ -308,21 +239,6 @@ session_timeout(void *arg) /* ================================================== */ -static int -check_alpn(NKSN_Instance inst) -{ - gnutls_datum_t alpn; - - if (gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn) < 0 || - alpn.size != sizeof (NKE_ALPN_NAME) - 1 || - memcmp(alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1) != 0) - return 0; - - return 1; -} - -/* ================================================== */ - static void set_input_output(NKSN_Instance inst, int output) { @@ -364,6 +280,7 @@ static int handle_event(NKSN_Instance inst, int event) { struct Message *message = &inst->message; + TLS_Status s; int r; DEBUG_LOG("Session event %d fd=%d state=%d", event, inst->sock_fd, (int)inst->state); @@ -390,56 +307,28 @@ handle_event(NKSN_Instance inst, int event) return 0; case KE_HANDSHAKE: - r = gnutls_handshake(inst->tls_session); - - if (r < 0) { - if (gnutls_error_is_fatal(r)) { - gnutls_datum_t cert_error; - - /* Get a description of verification errors */ - if (r != GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR || - gnutls_certificate_verification_status_print( - gnutls_session_get_verify_cert_status(inst->tls_session), - gnutls_certificate_type_get(inst->tls_session), &cert_error, 0) < 0) - cert_error.data = NULL; - - LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, - "TLS handshake with %s failed : %s%s%s", inst->label, gnutls_strerror(r), - cert_error.data ? " " : "", cert_error.data ? (const char *)cert_error.data : ""); - - if (cert_error.data) - gnutls_free(cert_error.data); - + s = TLS_DoHandshake(inst->tls_session); + + switch (s) { + case TLS_SUCCESS: + break; + case TLS_AGAIN_OUTPUT: + case TLS_AGAIN_INPUT: + set_input_output(inst, s == TLS_AGAIN_OUTPUT); + return 0; + default: stop_session(inst); /* Increase the retry interval if the handshake did not fail due to the other end closing the connection */ - if (r != GNUTLS_E_PULL_ERROR && r != GNUTLS_E_PREMATURE_TERMINATION) + if (s != TLS_CLOSED) inst->retry_factor = NKE_RETRY_FACTOR2_TLS; return 0; - } - - /* Disable output when the handshake is trying to receive data */ - set_input_output(inst, gnutls_record_get_direction(inst->tls_session)); - return 0; } inst->retry_factor = NKE_RETRY_FACTOR2_TLS; - if (DEBUG) { - char *description = gnutls_session_get_desc(inst->tls_session); - DEBUG_LOG("Handshake with %s completed %s", - inst->label, description ? description : ""); - gnutls_free(description); - } - - if (!check_alpn(inst)) { - LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "NTS-KE not supported by %s", inst->label); - stop_session(inst); - return 0; - } - /* Client will send a request to the server */ change_state(inst, inst->server ? KE_RECEIVE : KE_SEND); return 0; @@ -448,16 +337,17 @@ handle_event(NKSN_Instance inst, int event) assert(inst->new_message && message->complete); assert(message->length <= sizeof (message->data) && message->length > message->sent); - r = gnutls_record_send(inst->tls_session, &message->data[message->sent], - message->length - message->sent); + s = TLS_Send(inst->tls_session, &message->data[message->sent], + message->length - message->sent, &r); - if (r < 0) { - if (gnutls_error_is_fatal(r)) { - LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, - "Could not send NTS-KE message to %s : %s", inst->label, gnutls_strerror(r)); + switch (s) { + case TLS_SUCCESS: + break; + case TLS_AGAIN_OUTPUT: + return 0; + default: stop_session(inst); - } - return 0; + return 0; } DEBUG_LOG("Sent %d bytes to %s", r, inst->label); @@ -480,26 +370,24 @@ handle_event(NKSN_Instance inst, int event) return 0; } - r = gnutls_record_recv(inst->tls_session, &message->data[message->length], - sizeof (message->data) - message->length); + s = TLS_Receive(inst->tls_session, &message->data[message->length], + sizeof (message->data) - message->length, &r); - if (r < 0) { - /* Handle a renegotiation request on both client and server as - a protocol error */ - if (gnutls_error_is_fatal(r) || r == GNUTLS_E_REHANDSHAKE) { - LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, - "Could not receive NTS-KE message from %s : %s", - inst->label, gnutls_strerror(r)); + switch (s) { + case TLS_SUCCESS: + break; + case TLS_AGAIN_INPUT: + return 0; + default: stop_session(inst); - } - return 0; + return 0; } DEBUG_LOG("Received %d bytes from %s", r, inst->label); message->length += r; - } while (gnutls_record_check_pending(inst->tls_session) > 0); + } while (TLS_CheckPending(inst->tls_session)); if (!check_message_format(message, r == 0)) { LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, @@ -519,18 +407,18 @@ handle_event(NKSN_Instance inst, int event) return 1; case KE_SHUTDOWN: - r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR); - - if (r < 0) { - if (gnutls_error_is_fatal(r)) { - DEBUG_LOG("Shutdown with %s failed : %s", inst->label, gnutls_strerror(r)); + s = TLS_Shutdown(inst->tls_session); + + switch (s) { + case TLS_SUCCESS: + break; + case TLS_AGAIN_OUTPUT: + case TLS_AGAIN_INPUT: + set_input_output(inst, s == TLS_AGAIN_OUTPUT); + return 0; + default: stop_session(inst); return 0; - } - - /* Disable output when the TLS shutdown is trying to receive data */ - set_input_output(inst, gnutls_record_get_direction(inst->tls_session)); - return 0; } SCK_ShutdownConnection(inst->sock_fd); @@ -592,36 +480,18 @@ handle_step(struct timespec *raw, struct timespec *cooked, double dfreq, /* ================================================== */ -static int gnutls_initialised = 0; +static int tls_initialised = 0; static int -init_gnutls(void) +init_tls(void) { - int r; - - if (gnutls_initialised) + if (tls_initialised) return 1; - r = gnutls_global_init(); - if (r < 0) - LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r)); - - /* Prepare a priority cache for server and client NTS-KE sessions - (the NTS specification requires TLS1.3 or later) */ - r = gnutls_priority_init2(&priority_cache, - "-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2:-VERS-DTLS-ALL", - NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND); - if (r < 0) { - LOG(LOGS_ERR, "Could not initialise %s : %s", - "priority cache for TLS", gnutls_strerror(r)); - gnutls_global_deinit(); + if (!TLS_Initialise(&get_time)) return 0; - } - /* Use our clock instead of the system clock in certificate verification */ - gnutls_global_set_time_function(get_time); - - gnutls_initialised = 1; + tls_initialised = 1; DEBUG_LOG("Initialised"); LCL_AddParameterChangeHandler(handle_step, NULL); @@ -632,16 +502,15 @@ init_gnutls(void) /* ================================================== */ static void -deinit_gnutls(void) +deinit_tls(void) { - if (!gnutls_initialised || credentials_counter > 0) + if (!tls_initialised || credentials_counter > 0) return; LCL_RemoveParameterChangeHandler(handle_step, NULL); - gnutls_priority_deinit(priority_cache); - gnutls_global_deinit(); - gnutls_initialised = 0; + TLS_Finalise(); + tls_initialised = 0; DEBUG_LOG("Deinitialised"); } @@ -652,67 +521,21 @@ create_credentials(const char **certs, const char **keys, int n_certs_keys, const char **trusted_certs, uint32_t *trusted_certs_ids, int n_trusted_certs, uint32_t trusted_cert_set) { - gnutls_certificate_credentials_t credentials = NULL; - int i, r; + TLS_Credentials credentials; - if (!init_gnutls()) + if (!init_tls()) return NULL; - r = gnutls_certificate_allocate_credentials(&credentials); - if (r < 0) - goto error; - - if (certs && keys) { - BRIEF_ASSERT(!trusted_certs && !trusted_certs_ids); - - for (i = 0; i < n_certs_keys; i++) { - if (!UTI_CheckFilePermissions(keys[i], 0771)) - ; - r = gnutls_certificate_set_x509_key_file(credentials, certs[i], keys[i], - GNUTLS_X509_FMT_PEM); - if (r < 0) - goto error; - } - } else { - BRIEF_ASSERT(!certs && !keys && n_certs_keys <= 0); - - if (trusted_cert_set == 0 && !CNF_GetNoSystemCert()) { - r = gnutls_certificate_set_x509_system_trust(credentials); - if (r < 0) - goto error; - } - - if (trusted_certs && trusted_certs_ids) { - for (i = 0; i < n_trusted_certs; i++) { - struct stat buf; - - if (trusted_certs_ids[i] != trusted_cert_set) - continue; - - if (stat(trusted_certs[i], &buf) == 0 && S_ISDIR(buf.st_mode)) - r = gnutls_certificate_set_x509_trust_dir(credentials, trusted_certs[i], - GNUTLS_X509_FMT_PEM); - else - r = gnutls_certificate_set_x509_trust_file(credentials, trusted_certs[i], - GNUTLS_X509_FMT_PEM); - if (r < 0) - goto error; - - DEBUG_LOG("Added %d trusted certs from %s", r, trusted_certs[i]); - } - } + credentials = TLS_CreateCredentials(certs, keys, n_certs_keys, trusted_certs, + trusted_certs_ids, n_trusted_certs, trusted_cert_set); + if (!credentials) { + deinit_tls(); + return NULL; } credentials_counter++; - return (NKSN_Credentials)credentials; - -error: - LOG(LOGS_ERR, "Could not set credentials : %s", gnutls_strerror(r)); - if (credentials) - gnutls_certificate_free_credentials(credentials); - deinit_gnutls(); - return NULL; + return credentials; } /* ================================================== */ @@ -737,9 +560,9 @@ NKSN_CreateClientCertCredentials(const char **certs, uint32_t *ids, void NKSN_DestroyCertCredentials(NKSN_Credentials credentials) { - gnutls_certificate_free_credentials((gnutls_certificate_credentials_t)credentials); + TLS_DestroyCredentials(credentials); credentials_counter--; - deinit_gnutls(); + deinit_tls(); } /* ================================================== */ @@ -789,9 +612,9 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label, { assert(inst->state == KE_STOPPED); - inst->tls_session = create_tls_session(inst->server, sock_fd, inst->server_name, - (gnutls_certificate_credentials_t)credentials, - priority_cache); + inst->tls_session = TLS_CreateInstance(inst->server, sock_fd, inst->server_name, + NKE_ALPN_NAME, credentials, + clock_updates < CNF_GetNoCertTimeCheck()); if (!inst->tls_session) return 0; @@ -899,19 +722,15 @@ NKSN_GetKeys(NKSN_Instance inst, SIV_Algorithm algorithm, SIV_Algorithm exporter context.algorithm = htons(exporter_algorithm); context.is_s2c = 0; - if (gnutls_prf_rfc5705(inst->tls_session, - sizeof (NKE_EXPORTER_LABEL) - 1, NKE_EXPORTER_LABEL, - sizeof (context) - 1, (char *)&context, - length, (char *)c2s->key) < 0) { + if (!TLS_ExportKey(inst->tls_session, sizeof (NKE_EXPORTER_LABEL) - 1, NKE_EXPORTER_LABEL, + sizeof (context) - 1, &context, length, c2s->key)) { DEBUG_LOG("Could not export key"); return 0; } context.is_s2c = 1; - if (gnutls_prf_rfc5705(inst->tls_session, - sizeof (NKE_EXPORTER_LABEL) - 1, NKE_EXPORTER_LABEL, - sizeof (context) - 1, (char *)&context, - length, (char *)s2c->key) < 0) { + if (!TLS_ExportKey(inst->tls_session, sizeof (NKE_EXPORTER_LABEL) - 1, NKE_EXPORTER_LABEL, + sizeof (context) - 1, &context, length, s2c->key)) { DEBUG_LOG("Could not export key"); return 0; } diff --git a/tls.h b/tls.h new file mode 100644 index 00000000..0a283821 --- /dev/null +++ b/tls.h @@ -0,0 +1,93 @@ +/* + chronyd/chronyc - Programs for keeping computer clocks accurate. + + ********************************************************************** + * Copyright (C) Anthony Brandon 2025 + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + * + ********************************************************************** + + ======================================================================= + + Header file for the TLS session + */ + +#ifndef GOT_TLS_H +#define GOT_TLS_H + +struct TLS_Instance_Record; + +typedef struct TLS_Instance_Record *TLS_Instance; + +typedef void *TLS_Credentials; + +typedef enum { + /* TLS operation succeeded */ + TLS_SUCCESS, + /* TLS operation failed. + No more operations should be called and the session should be destroyed. */ + TLS_FAILED, + /* TLS session closed by other end */ + TLS_CLOSED, + /* The last TLS operation should be called again when input is ready */ + TLS_AGAIN_INPUT, + /* The last TLS operation should be called again when output is ready */ + TLS_AGAIN_OUTPUT, +} TLS_Status; + +/* Initialize TLS */ +extern int TLS_Initialise(time_t (*get_time)(time_t *t)); + +/* Deinitialize TLS */ +extern void TLS_Finalise(void); + +/* Create new TLS credentials instance */ +extern TLS_Credentials TLS_CreateCredentials(const char **certs, const char **keys, + int n_certs_keys, const char **trusted_certs, + uint32_t * trusted_certs_ids, int n_trusted_certs, + uint32_t trusted_cert_set); + +/* Destroy TLS credentials instance */ +extern void TLS_DestroyCredentials(TLS_Credentials credentials); + +/* Create new TLS session instance */ +extern TLS_Instance TLS_CreateInstance(int server_mode, int sock_fd, const char *server_name, + const char *alpn_name, TLS_Credentials credentials, + int disable_time_checks); + +/* Destroy TLS instance */ +extern void TLS_DestroyInstance(TLS_Instance inst); + +/* Perform TLS handshake */ +extern TLS_Status TLS_DoHandshake(TLS_Instance inst); + +/* Send data over TLS */ +extern TLS_Status TLS_Send(TLS_Instance inst, const void *data, int length, int *sent); + +/* Receive data over TLS */ +extern TLS_Status TLS_Receive(TLS_Instance inst, void *data, int length, int *received); + +/* Check if there is data pending to read */ +extern int TLS_CheckPending(TLS_Instance inst); + +/* Perform TLS shutdown */ +extern TLS_Status TLS_Shutdown(TLS_Instance inst); + +/* Export key from TLS instance */ +extern int TLS_ExportKey(TLS_Instance inst, int label_length, const char *label, + int context_length, const void *context, int key_length, + unsigned char *key); + +#endif diff --git a/tls_gnutls.c b/tls_gnutls.c new file mode 100644 index 00000000..2def9ee0 --- /dev/null +++ b/tls_gnutls.c @@ -0,0 +1,413 @@ +/* + chronyd/chronyc - Programs for keeping computer clocks accurate. + + ********************************************************************** + * Copyright (C) Miroslav Lichvar 2020-2021 + * Copyright (C) Anthony Brandon 2025 + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + * + ********************************************************************** + + ======================================================================= + + Routines implementing TLS session handling using the gnutls library. + */ + +#include "config.h" + +#include "sysincl.h" + +#include "tls.h" + +#include "conf.h" +#include "logging.h" +#include "memory.h" +#include "util.h" + +#include +#include + +struct TLS_Instance_Record { + gnutls_session_t session; + int server; + char *server_name; + char *alpn_name; +}; + +/* ================================================== */ + +static gnutls_priority_t priority_cache; + +/* ================================================== */ + +int +TLS_Initialise(time_t (*get_time)(time_t *t)) +{ + int r = gnutls_global_init(); + + if (r < 0) + LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r)); + + /* Prepare a priority cache for server and client NTS-KE sessions + (the NTS specification requires TLS1.3 or later) */ + r = gnutls_priority_init2(&priority_cache, + "-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2:-VERS-DTLS-ALL", + NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND); + if (r < 0) { + LOG(LOGS_ERR, "Could not initialise %s : %s", + "priority cache for TLS", gnutls_strerror(r)); + gnutls_global_deinit(); + return 0; + } + + /* Use our clock instead of the system clock in certificate verification */ + gnutls_global_set_time_function(get_time); + + return 1; +} + +/* ================================================== */ + +void +TLS_Finalise(void) +{ + gnutls_priority_deinit(priority_cache); + gnutls_global_deinit(); +} + +/* ================================================== */ + +TLS_Credentials +TLS_CreateCredentials(const char **certs, const char **keys, int n_certs_keys, + const char **trusted_certs, uint32_t *trusted_certs_ids, + int n_trusted_certs, uint32_t trusted_cert_set) +{ + gnutls_certificate_credentials_t credentials = NULL; + int i, r; + + r = gnutls_certificate_allocate_credentials(&credentials); + if (r < 0) + goto error; + + if (certs && keys) { + BRIEF_ASSERT(!trusted_certs && !trusted_certs_ids); + + for (i = 0; i < n_certs_keys; i++) { + if (!UTI_CheckFilePermissions(keys[i], 0771)) + ; + r = gnutls_certificate_set_x509_key_file(credentials, certs[i], keys[i], + GNUTLS_X509_FMT_PEM); + if (r < 0) + goto error; + } + } else { + BRIEF_ASSERT(!certs && !keys && n_certs_keys <= 0); + + if (trusted_cert_set == 0 && !CNF_GetNoSystemCert()) { + r = gnutls_certificate_set_x509_system_trust(credentials); + if (r < 0) + goto error; + } + + if (trusted_certs && trusted_certs_ids) { + for (i = 0; i < n_trusted_certs; i++) { + struct stat buf; + + if (trusted_certs_ids[i] != trusted_cert_set) + continue; + + if (stat(trusted_certs[i], &buf) == 0 && S_ISDIR(buf.st_mode)) + r = gnutls_certificate_set_x509_trust_dir(credentials, trusted_certs[i], + GNUTLS_X509_FMT_PEM); + else + r = gnutls_certificate_set_x509_trust_file(credentials, trusted_certs[i], + GNUTLS_X509_FMT_PEM); + if (r < 0) + goto error; + + DEBUG_LOG("Added %d trusted certs from %s", r, trusted_certs[i]); + } + } + } + + return credentials; + +error: + LOG(LOGS_ERR, "Could not set credentials : %s", gnutls_strerror(r)); + if (credentials) + gnutls_certificate_free_credentials(credentials); + return NULL; +} + +/* ================================================== */ + +void +TLS_DestroyCredentials(TLS_Credentials credentials) +{ + gnutls_certificate_free_credentials((gnutls_certificate_credentials_t)credentials); +} + +/* ================================================== */ + +TLS_Instance +TLS_CreateInstance(int server_mode, int sock_fd, const char *server_name, + const char *alpn_name, TLS_Credentials credentials, int disable_time_checks) +{ + gnutls_datum_t alpn; + unsigned int flags; + int r; + + TLS_Instance inst = MallocNew(struct TLS_Instance_Record); + + inst->session = NULL; + inst->server = server_mode; + inst->server_name = server_name ? Strdup(server_name) : NULL; + inst->alpn_name = alpn_name ? Strdup(alpn_name) : NULL; + + r = gnutls_init(&inst->session, GNUTLS_NONBLOCK | GNUTLS_NO_TICKETS | + (server_mode ? GNUTLS_SERVER : GNUTLS_CLIENT)); + if (r < 0) { + LOG(LOGS_ERR, "Could not %s TLS session : %s", "create", gnutls_strerror(r)); + goto error; + } + + if (!server_mode) { + assert(server_name); + + if (!UTI_IsStringIP(server_name)) { + r = gnutls_server_name_set(inst->session, GNUTLS_NAME_DNS, server_name, + strlen(server_name)); + if (r < 0) + goto error; + } + + flags = 0; + + if (disable_time_checks) { + flags |= GNUTLS_VERIFY_DISABLE_TIME_CHECKS | GNUTLS_VERIFY_DISABLE_TRUSTED_TIME_CHECKS; + DEBUG_LOG("Disabled time checks"); + } + + gnutls_session_set_verify_cert(inst->session, server_name, flags); + } + + r = gnutls_priority_set(inst->session, priority_cache); + if (r < 0) + goto error; + + r = gnutls_credentials_set(inst->session, GNUTLS_CRD_CERTIFICATE, credentials); + if (r < 0) + goto error; + + alpn.data = (unsigned char *)inst->alpn_name; + alpn.size = strlen(inst->alpn_name); + + r = gnutls_alpn_set_protocols(inst->session, &alpn, 1, 0); + if (r < 0) + goto error; + + gnutls_transport_set_int(inst->session, sock_fd); + + return inst; + +error: + LOG(LOGS_ERR, "Could not %s TLS session : %s", "set", gnutls_strerror(r)); + TLS_DestroyInstance(inst); + return NULL; +} + +/* ================================================== */ + +void +TLS_DestroyInstance(TLS_Instance inst) +{ + if (inst->session) + gnutls_deinit(inst->session); + + if (inst->server_name) + Free(inst->server_name); + + if (inst->alpn_name) + Free(inst->alpn_name); + + Free(inst); +} + +/* ================================================== */ + +static int +check_alpn(TLS_Instance inst) +{ + gnutls_datum_t alpn; + int length = strlen(inst->alpn_name); + + if (gnutls_alpn_get_selected_protocol(inst->session, &alpn) < 0 || + alpn.size != length || memcmp(alpn.data, inst->alpn_name, length) != 0) + return 0; + + return 1; +} + +/* ================================================== */ + +TLS_Status +TLS_DoHandshake(TLS_Instance inst) +{ + int r = gnutls_handshake(inst->session); + + if (r < 0) { + if (gnutls_error_is_fatal(r)) { + gnutls_datum_t cert_error; + + /* Get a description of verification errors */ + if (r != GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR || + gnutls_certificate_verification_status_print( + gnutls_session_get_verify_cert_status(inst->session), + gnutls_certificate_type_get(inst->session), &cert_error, 0) < 0) + cert_error.data = NULL; + + LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, + "TLS handshake with %s failed : %s%s%s", inst->server_name, gnutls_strerror(r), + cert_error.data ? " " : "", cert_error.data ? (const char *)cert_error.data : ""); + + if (cert_error.data) + gnutls_free(cert_error.data); + + /* Increase the retry interval if the handshake did not fail due + to the other end closing the connection */ + if (r != GNUTLS_E_PULL_ERROR && r != GNUTLS_E_PREMATURE_TERMINATION) + return TLS_FAILED; + + return TLS_CLOSED; + } + + return gnutls_record_get_direction(inst->session) ? TLS_AGAIN_OUTPUT : TLS_AGAIN_INPUT; + } + + if (DEBUG) { + char *description = gnutls_session_get_desc(inst->session); + DEBUG_LOG("Handshake with %s completed %s", inst->server_name, + description ? description : ""); + gnutls_free(description); + } + + if (!check_alpn(inst)) { + LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "NTS-KE not supported by %s", inst->server_name); + return TLS_FAILED; + } + + return TLS_SUCCESS; +} + +/* ================================================== */ + +TLS_Status +TLS_Send(TLS_Instance inst, const void *data, int length, int *sent) +{ + int r; + + if (length < 0) + return TLS_FAILED; + + r = gnutls_record_send(inst->session, data, length); + + if (r < 0) { + if (gnutls_error_is_fatal(r)) { + LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, + "Could not send NTS-KE message to %s : %s", inst->server_name, gnutls_strerror(r)); + return TLS_FAILED; + } + + return TLS_AGAIN_OUTPUT; + } + + *sent = r; + + return TLS_SUCCESS; +} + +/* ================================================== */ + +TLS_Status +TLS_Receive(TLS_Instance inst, void *data, int length, int *received) +{ + int r; + + if (length < 0) + return TLS_FAILED; + + r = gnutls_record_recv(inst->session, data, length); + + if (r < 0) { + /* Handle a renegotiation request on both client and server as + a protocol error */ + if (gnutls_error_is_fatal(r) || r == GNUTLS_E_REHANDSHAKE) { + LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, + "Could not receive NTS-KE message from %s : %s", + inst->server_name, gnutls_strerror(r)); + return TLS_FAILED; + } + + return TLS_AGAIN_INPUT; + } + + *received = r; + + return TLS_SUCCESS; +} + +/* ================================================== */ + +int +TLS_CheckPending(TLS_Instance inst) +{ + return gnutls_record_check_pending(inst->session) > 0; +} + +/* ================================================== */ + +TLS_Status +TLS_Shutdown(TLS_Instance inst) +{ + int r = gnutls_bye(inst->session, GNUTLS_SHUT_RDWR); + + if (r < 0) { + if (gnutls_error_is_fatal(r)) { + DEBUG_LOG("Shutdown with %s failed : %s", inst->server_name, gnutls_strerror(r)); + return TLS_FAILED; + } + + return gnutls_record_get_direction(inst->session) ? TLS_AGAIN_OUTPUT : TLS_AGAIN_INPUT; + } + + return TLS_SUCCESS; +} + +/* ================================================== */ + +int +TLS_ExportKey(TLS_Instance inst, int label_length, const char *label, int context_length, + const void *context, int key_length, unsigned char *key) +{ + int r; + + if (label_length < 0 || context_length < 0 || key_length < 0) + return 0; + + r = gnutls_prf_rfc5705(inst->session, label_length, label, context_length, (char *)context, + key_length, (char *)key); + + return r >= 0; +}