From 9a5ac06921357bdfd4e2f74b5b32955464bf9b75 Mon Sep 17 00:00:00 2001 From: Neil Horman Date: Thu, 29 May 2025 15:26:02 +0100 Subject: [PATCH] Add test for yielding of write secrets before read Test that, in QUIC, we yield write secrets before read secrets Reviewed-by: Tomas Mraz Reviewed-by: Matt Caswell (Merged from https://github.com/openssl/openssl/pull/27732) --- test/helpers/ssltestlib.c | 34 ++++++++-- test/helpers/ssltestlib.h | 4 ++ test/sslapitest.c | 129 +++++++++++++++++++++++++++++++++++++- 3 files changed, 160 insertions(+), 7 deletions(-) diff --git a/test/helpers/ssltestlib.c b/test/helpers/ssltestlib.c index 10618905c4c..56d526f7525 100644 --- a/test/helpers/ssltestlib.c +++ b/test/helpers/ssltestlib.c @@ -1258,12 +1258,13 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl, * has SSL_get_error() return the value in the |want| parameter. The connection * attempt could be restarted by a subsequent call to this function. */ -int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, - int read, int listen) +int create_bare_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want, + int read, int listen, int *cm_count, int *sm_count) { int retc = -1, rets = -1, err, abortctr = 0, ret = 0; int clienterr = 0, servererr = 0; int isdtls = SSL_is_dtls(serverssl); + int icm_count = 0, ism_count = 0; #ifndef OPENSSL_NO_SOCK BIO_ADDR *peer = NULL; @@ -1289,6 +1290,7 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, retc = SSL_connect(clientssl); if (retc <= 0) err = SSL_get_error(clientssl, retc); + icm_count++; } if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) { @@ -1314,12 +1316,14 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, listen = 0; rets = 0; } + ism_count++; } else #endif { rets = SSL_accept(serverssl); if (rets <= 0) err = SSL_get_error(serverssl, rets); + ism_count++; } } @@ -1345,6 +1349,7 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, TEST_info("Unexpected SSL_read() success!"); goto err; } + ism_count++; } if (retc > 0 && rets <= 0) { if (SSL_read(clientssl, buf, sizeof(buf)) > 0) { @@ -1352,6 +1357,7 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, TEST_info("Unexpected SSL_read() success!"); goto err; } + icm_count++; } } if (++abortctr == MAXLOOPS) { @@ -1370,23 +1376,36 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, ret = 1; err: + if (cm_count != NULL) + *cm_count = icm_count; + if (sm_count != NULL) + *sm_count = ism_count; #ifndef OPENSSL_NO_SOCK BIO_ADDR_free(peer); #endif return ret; } +int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, + int read, int listen) +{ + return create_bare_ssl_connection_ex(serverssl, clientssl, want, read, + listen, NULL, NULL); +} + /* * Create an SSL connection including any post handshake NewSessionTicket * messages. */ -int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want) +int create_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want, + int *cm_count, int *sm_count) { int i; unsigned char buf; size_t readbytes; - if (!create_bare_ssl_connection(serverssl, clientssl, want, 1, 0)) + if (!create_bare_ssl_connection_ex(serverssl, clientssl, want, 1, 0, + cm_count, sm_count)) return 0; /* @@ -1402,11 +1421,18 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want) SSL_ERROR_WANT_READ)) { return 0; } + if (cm_count != NULL) + (*cm_count)++; } return 1; } +int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want) +{ + return create_ssl_connection_ex(serverssl, clientssl, want, NULL, NULL); +} + void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl) { SSL_shutdown(clientssl); diff --git a/test/helpers/ssltestlib.h b/test/helpers/ssltestlib.h index 16f679cf5f0..b5606041a79 100644 --- a/test/helpers/ssltestlib.h +++ b/test/helpers/ssltestlib.h @@ -28,11 +28,15 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl, SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio); int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want, int read, int listen); +int create_bare_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want, + int read, int listen, int *cm_count, int *sm_count); int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl, SSL **cssl, int sfd, int cfd); int wait_until_sock_readable(int sock); int create_test_sockets(int *cfdp, int *sfdp, int socktype, BIO_ADDR *saddr); int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want); +int create_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want, + int *cm_count, int *sm_count); void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl); /* Note: Not thread safe! */ diff --git a/test/sslapitest.c b/test/sslapitest.c index 207382e38b3..f68f12dac50 100644 --- a/test/sslapitest.c +++ b/test/sslapitest.c @@ -12636,6 +12636,7 @@ struct quic_tls_test_data { int alert; int err; int forcefail; + int sm_count; }; static int clientquicdata = 0xff, serverquicdata = 0xfe; @@ -12723,6 +12724,94 @@ static int crypto_release_rcd_cb(SSL *s, size_t bytes_read, void *arg) return 1; } +struct secret_yield_entry { + uint8_t recorded; + int prot_level; + int direction; + int sm_generation; + SSL *ssl; +}; + +static struct secret_yield_entry secret_history[16]; +static int secret_history_idx = 0; +/* + * Note, this enum needs to match the direction values passed + * to yield_secret_cb + */ +typedef enum { + LAST_DIR_READ = 0, + LAST_DIR_WRITE = 1, + LAST_DIR_UNSET = 2 +} last_dir_history_state; + +static int check_secret_history(SSL *s) +{ + int i; + int ret = 0; + last_dir_history_state last_state = LAST_DIR_UNSET; + int last_prot_level = 0; + int last_generation = 0; + + TEST_info("Checking history for %p\n", (void *)s); + for (i = 0; secret_history[i].recorded == 1; i++) { + if (secret_history[i].ssl != s) + continue; + TEST_info("Got %s(%d) secret for level %d, last level %d, last state %d, gen %d\n", + secret_history[i].direction == 1 ? "Write" : "Read", secret_history[i].direction, + secret_history[i].prot_level, last_prot_level, last_state, + secret_history[i].sm_generation); + + if (last_state == LAST_DIR_UNSET) { + last_prot_level = secret_history[i].prot_level; + last_state = secret_history[i].direction; + last_generation = secret_history[i].sm_generation; + continue; + } + + switch(secret_history[i].direction) { + case 1: + /* + * write case + * NOTE: There is an odd corner case here. It may occur that + * in a single iteration of the state machine, the read key is yielded + * prior to the write key for the same level. This is undesireable + * for quic, but it is ok, as the general implementation of every 3rd + * party quic stack while prefering write keys before read, allows + * for read before write if both keys are yielded in the same call + * to SSL_do_handshake, as the tls adaptation code for that quic stack + * can then cache keys until both are available, so we allow read before + * write here iff they occur in the same iteration of SSL_do_handshake + * as represented by the recorded sm_generation value. + */ + if (last_prot_level == secret_history[i].prot_level + && last_state == LAST_DIR_READ) { + if (last_generation == secret_history[i].sm_generation) { + TEST_info("Read before write key in same SSL state machine iteration is ok"); + } else { + TEST_error("Got read key before write key"); + goto end; + } + } + /* FALLTHROUGH */ + case 0: + /* + * Read case + */ + break; + default: + TEST_error("Unknown direction"); + goto end; + } + last_prot_level = secret_history[i].prot_level; + last_state = secret_history[i].direction; + last_generation = secret_history[i].sm_generation; + } + + ret = 1; +end: + return ret; +} + static int yield_secret_cb(SSL *s, uint32_t prot_level, int direction, const unsigned char *secret, size_t secret_len, void *arg) @@ -12757,6 +12846,12 @@ static int yield_secret_cb(SSL *s, uint32_t prot_level, int direction, goto err; } + secret_history[secret_history_idx].direction = direction; + secret_history[secret_history_idx].prot_level = (int)prot_level; + secret_history[secret_history_idx].recorded = 1; + secret_history[secret_history_idx].ssl = s; + secret_history[secret_history_idx].sm_generation = data->sm_count; + secret_history_idx++; return 1; err: data->err = 1; @@ -12851,6 +12946,8 @@ static int test_quic_tls(int idx) if (idx == 4) qtdis[3].function = (void (*)(void))yield_secret_cb_fail; + memset(secret_history, 0, sizeof(secret_history)); + secret_history_idx = 0; memset(&sdata, 0, sizeof(sdata)); memset(&cdata, 0, sizeof(cdata)); sdata.peer = &cdata; @@ -12890,11 +12987,13 @@ static int test_quic_tls(int idx) goto end; if (idx != 1 && idx != 4) { - if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) + if (!TEST_true(create_ssl_connection_ex(serverssl, clientssl, SSL_ERROR_NONE, + &cdata.sm_count, &sdata.sm_count))) goto end; } else { /* We expect this connection to fail */ - if (!TEST_false(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) + if (!TEST_false(create_ssl_connection_ex(serverssl, clientssl, SSL_ERROR_NONE, + &cdata.sm_count, &sdata.sm_count))) goto end; testresult = 1; sdata.err = 0; @@ -12917,6 +13016,14 @@ static int test_quic_tls(int idx) goto end; } + /* + * Check that our secret history yields write secrets before read secrets + */ + if (!TEST_int_eq(check_secret_history(serverssl), 1)) + goto end; + if (!TEST_int_eq(check_secret_history(clientssl), 1)) + goto end; + /* Check the transport params */ if (!TEST_mem_eq(sdata.params, sdata.params_len, cparams, sizeof(cparams)) || !TEST_mem_eq(cdata.params, cdata.params_len, sparams, @@ -12981,6 +13088,8 @@ static int test_quic_tls_early_data(void) }; int i; + memset(secret_history, 0, sizeof(secret_history)); + secret_history_idx = 0; memset(&sdata, 0, sizeof(sdata)); memset(&cdata, 0, sizeof(cdata)); sdata.peer = &cdata; @@ -13030,6 +13139,12 @@ static int test_quic_tls_early_data(void) sizeof(sparams)))) goto end; + /* + * Reset our secret history so we get the record of the second connection + */ + memset(secret_history, 0, sizeof(secret_history)); + secret_history_idx = 0; + SSL_set_quic_tls_early_data_enabled(serverssl, 1); SSL_set_quic_tls_early_data_enabled(clientssl, 1); @@ -13050,7 +13165,10 @@ static int test_quic_tls_early_data(void) || !TEST_true(cdata.wenc_level == OSSL_RECORD_PROTECTION_LEVEL_EARLY)) goto end; - if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) + sdata.sm_count = 0; + cdata.sm_count = 0; + if (!TEST_true(create_ssl_connection_ex(serverssl, clientssl, SSL_ERROR_NONE, + &cdata.sm_count, &sdata.sm_count))) goto end; /* Check no problems during the handshake */ @@ -13069,6 +13187,11 @@ static int test_quic_tls_early_data(void) goto end; } + if (!TEST_int_eq(check_secret_history(serverssl), 1)) + goto end; + if (!TEST_int_eq(check_secret_history(clientssl), 1)) + goto end; + /* Check the transport params */ if (!TEST_mem_eq(sdata.params, sdata.params_len, cparams, sizeof(cparams)) || !TEST_mem_eq(cdata.params, cdata.params_len, sparams, -- 2.47.2