]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Add test for yielding of write secrets before read
authorNeil Horman <nhorman@openssl.org>
Thu, 29 May 2025 14:26:02 +0000 (15:26 +0100)
committerMatt Caswell <matt@openssl.org>
Tue, 3 Jun 2025 16:06:41 +0000 (17:06 +0100)
Test that, in QUIC, we yield write secrets before read secrets

Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/27732)

test/helpers/ssltestlib.c
test/helpers/ssltestlib.h
test/sslapitest.c

index 10618905c4c41e00c186171fa1c89500859ff178..56d526f7525e0aca0072235b825528a84d7aca42 100644 (file)
@@ -1258,12 +1258,13 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
  * has SSL_get_error() return the value in the |want| parameter. The connection
  * attempt could be restarted by a subsequent call to this function.
  */
-int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
-                               int read, int listen)
+int create_bare_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want,
+                                  int read, int listen, int *cm_count, int *sm_count)
 {
     int retc = -1, rets = -1, err, abortctr = 0, ret = 0;
     int clienterr = 0, servererr = 0;
     int isdtls = SSL_is_dtls(serverssl);
+    int icm_count = 0, ism_count = 0;
 #ifndef OPENSSL_NO_SOCK
     BIO_ADDR *peer = NULL;
 
@@ -1289,6 +1290,7 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
             retc = SSL_connect(clientssl);
             if (retc <= 0)
                 err = SSL_get_error(clientssl, retc);
+            icm_count++;
         }
 
         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
@@ -1314,12 +1316,14 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
                     listen = 0;
                     rets = 0;
                 }
+                ism_count++;
             } else
 #endif
             {
                 rets = SSL_accept(serverssl);
                 if (rets <= 0)
                     err = SSL_get_error(serverssl, rets);
+                ism_count++;
             }
         }
 
@@ -1345,6 +1349,7 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
                     TEST_info("Unexpected SSL_read() success!");
                     goto err;
                 }
+                ism_count++;
             }
             if (retc > 0 && rets <= 0) {
                 if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
@@ -1352,6 +1357,7 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
                     TEST_info("Unexpected SSL_read() success!");
                     goto err;
                 }
+                icm_count++;
             }
         }
         if (++abortctr == MAXLOOPS) {
@@ -1370,23 +1376,36 @@ int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
 
     ret = 1;
  err:
+    if (cm_count != NULL)
+        *cm_count = icm_count;
+    if (sm_count != NULL)
+        *sm_count = ism_count;
 #ifndef OPENSSL_NO_SOCK
     BIO_ADDR_free(peer);
 #endif
     return ret;
 }
 
+int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
+                               int read, int listen)
+{
+    return create_bare_ssl_connection_ex(serverssl, clientssl, want, read,
+                                         listen, NULL, NULL);
+}
+
 /*
  * Create an SSL connection including any post handshake NewSessionTicket
  * messages.
  */
-int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+int create_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want,
+                             int *cm_count, int *sm_count)
 {
     int i;
     unsigned char buf;
     size_t readbytes;
 
-    if (!create_bare_ssl_connection(serverssl, clientssl, want, 1, 0))
+    if (!create_bare_ssl_connection_ex(serverssl, clientssl, want, 1, 0,
+                                       cm_count, sm_count))
         return 0;
 
     /*
@@ -1402,11 +1421,18 @@ int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
                                 SSL_ERROR_WANT_READ)) {
             return 0;
         }
+        if (cm_count != NULL)
+            (*cm_count)++;
     }
 
     return 1;
 }
 
+int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
+{
+   return create_ssl_connection_ex(serverssl, clientssl, want, NULL, NULL);
+}
+
 void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
 {
     SSL_shutdown(clientssl);
index 16f679cf5f0f9db8117728286d8342ee4cdb328d..b5606041a799f9d97ed631688deecea791519144 100644 (file)
@@ -28,11 +28,15 @@ int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
                        SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio);
 int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
                                int read, int listen);
+int create_bare_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want,
+                                  int read, int listen, int *cm_count, int *sm_count);
 int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
                        SSL **cssl, int sfd, int cfd);
 int wait_until_sock_readable(int sock);
 int create_test_sockets(int *cfdp, int *sfdp, int socktype, BIO_ADDR *saddr);
 int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want);
+int create_ssl_connection_ex(SSL *serverssl, SSL *clientssl, int want,
+                             int *cm_count, int *sm_count);
 void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl);
 
 /* Note: Not thread safe! */
index 207382e38b369c7ee1725a0d0a031bbba842beb8..f68f12dac50304a19b30cbf98f0978c162edf8f2 100644 (file)
@@ -12636,6 +12636,7 @@ struct quic_tls_test_data {
     int alert;
     int err;
     int forcefail;
+    int sm_count;
 };
 
 static int clientquicdata = 0xff, serverquicdata = 0xfe;
@@ -12723,6 +12724,94 @@ static int crypto_release_rcd_cb(SSL *s, size_t bytes_read, void *arg)
     return 1;
 }
 
+struct secret_yield_entry {
+    uint8_t recorded;
+    int prot_level;
+    int direction;
+    int sm_generation;
+    SSL *ssl;
+};
+
+static struct secret_yield_entry secret_history[16];
+static int secret_history_idx = 0;
+/*
+ * Note, this enum needs to match the direction values passed
+ * to yield_secret_cb
+ */
+typedef enum {
+    LAST_DIR_READ = 0,
+    LAST_DIR_WRITE = 1,
+    LAST_DIR_UNSET = 2
+} last_dir_history_state;
+
+static int check_secret_history(SSL *s)
+{
+    int i;
+    int ret = 0;
+    last_dir_history_state last_state = LAST_DIR_UNSET;
+    int last_prot_level = 0;
+    int last_generation = 0;
+
+    TEST_info("Checking history for %p\n", (void *)s);
+    for (i = 0; secret_history[i].recorded == 1; i++) {
+        if (secret_history[i].ssl != s)
+            continue;
+        TEST_info("Got %s(%d) secret for level %d, last level %d, last state %d, gen %d\n",
+                  secret_history[i].direction == 1 ? "Write" : "Read", secret_history[i].direction,
+                  secret_history[i].prot_level, last_prot_level, last_state,
+                  secret_history[i].sm_generation);
+
+        if (last_state == LAST_DIR_UNSET) {
+            last_prot_level = secret_history[i].prot_level;
+            last_state = secret_history[i].direction;
+            last_generation = secret_history[i].sm_generation;
+            continue;
+        }
+
+        switch(secret_history[i].direction) {
+        case 1:
+            /*
+             * write case
+             * NOTE: There is an odd corner case here.  It may occur that
+             * in a single iteration of the state machine, the read key is yielded
+             * prior to the write key for the same level.  This is undesireable
+             * for quic, but it is ok, as the general implementation of every 3rd
+             * party quic stack while prefering write keys before read, allows
+             * for read before write if both keys are yielded in the same call
+             * to SSL_do_handshake, as the tls adaptation code for that quic stack
+             * can then cache keys until both are available, so we allow read before
+             * write here iff they occur in the same iteration of SSL_do_handshake
+             * as represented by the recorded sm_generation value.
+             */
+            if (last_prot_level == secret_history[i].prot_level
+                && last_state == LAST_DIR_READ) {
+                if (last_generation == secret_history[i].sm_generation) {
+                    TEST_info("Read before write key in same SSL state machine iteration is ok");
+                } else {
+                    TEST_error("Got read key before write key");
+                    goto end;
+                }
+            }
+            /* FALLTHROUGH */
+        case 0:
+            /*
+             * Read case
+             */
+            break;
+        default:
+            TEST_error("Unknown direction");
+            goto end;
+        }
+        last_prot_level = secret_history[i].prot_level;
+        last_state = secret_history[i].direction;
+        last_generation = secret_history[i].sm_generation;
+    }
+
+    ret = 1;
+end:
+    return ret;
+}
+
 static int yield_secret_cb(SSL *s, uint32_t prot_level, int direction,
                            const unsigned char *secret, size_t secret_len,
                            void *arg)
@@ -12757,6 +12846,12 @@ static int yield_secret_cb(SSL *s, uint32_t prot_level, int direction,
         goto err;
     }
 
+    secret_history[secret_history_idx].direction = direction;
+    secret_history[secret_history_idx].prot_level = (int)prot_level;
+    secret_history[secret_history_idx].recorded = 1;
+    secret_history[secret_history_idx].ssl = s;
+    secret_history[secret_history_idx].sm_generation = data->sm_count;
+    secret_history_idx++;
     return 1;
  err:
     data->err = 1;
@@ -12851,6 +12946,8 @@ static int test_quic_tls(int idx)
     if (idx == 4)
         qtdis[3].function = (void (*)(void))yield_secret_cb_fail;
 
+    memset(secret_history, 0, sizeof(secret_history));
+    secret_history_idx = 0;
     memset(&sdata, 0, sizeof(sdata));
     memset(&cdata, 0, sizeof(cdata));
     sdata.peer = &cdata;
@@ -12890,11 +12987,13 @@ static int test_quic_tls(int idx)
         goto end;
 
     if (idx != 1 && idx != 4) {
-        if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)))
+        if (!TEST_true(create_ssl_connection_ex(serverssl, clientssl, SSL_ERROR_NONE,
+                                                &cdata.sm_count, &sdata.sm_count)))
             goto end;
     } else {
         /* We expect this connection to fail */
-        if (!TEST_false(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)))
+        if (!TEST_false(create_ssl_connection_ex(serverssl, clientssl, SSL_ERROR_NONE,
+                                                 &cdata.sm_count, &sdata.sm_count)))
             goto end;
         testresult = 1;
         sdata.err = 0;
@@ -12917,6 +13016,14 @@ static int test_quic_tls(int idx)
             goto end;
     }
 
+    /*
+     * Check that our secret history yields write secrets before read secrets
+     */
+    if (!TEST_int_eq(check_secret_history(serverssl), 1))
+        goto end;
+    if (!TEST_int_eq(check_secret_history(clientssl), 1))
+        goto end;
+
     /* Check the transport params */
     if (!TEST_mem_eq(sdata.params, sdata.params_len, cparams, sizeof(cparams))
             || !TEST_mem_eq(cdata.params, cdata.params_len, sparams,
@@ -12981,6 +13088,8 @@ static int test_quic_tls_early_data(void)
     };
     int i;
 
+    memset(secret_history, 0, sizeof(secret_history));
+    secret_history_idx = 0;
     memset(&sdata, 0, sizeof(sdata));
     memset(&cdata, 0, sizeof(cdata));
     sdata.peer = &cdata;
@@ -13030,6 +13139,12 @@ static int test_quic_tls_early_data(void)
                                                             sizeof(sparams))))
         goto end;
 
+    /*
+     * Reset our secret history so we get the record of the second connection
+     */
+    memset(secret_history, 0, sizeof(secret_history));
+    secret_history_idx = 0;
+
     SSL_set_quic_tls_early_data_enabled(serverssl, 1);
     SSL_set_quic_tls_early_data_enabled(clientssl, 1);
 
@@ -13050,7 +13165,10 @@ static int test_quic_tls_early_data(void)
             || !TEST_true(cdata.wenc_level == OSSL_RECORD_PROTECTION_LEVEL_EARLY))
         goto end;
 
-    if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)))
+    sdata.sm_count = 0;
+    cdata.sm_count = 0;
+    if (!TEST_true(create_ssl_connection_ex(serverssl, clientssl, SSL_ERROR_NONE,
+                                            &cdata.sm_count, &sdata.sm_count)))
         goto end;
 
     /* Check no problems during the handshake */
@@ -13069,6 +13187,11 @@ static int test_quic_tls_early_data(void)
             goto end;
     }
 
+    if (!TEST_int_eq(check_secret_history(serverssl), 1))
+        goto end;
+    if (!TEST_int_eq(check_secret_history(clientssl), 1))
+        goto end;
+
     /* Check the transport params */
     if (!TEST_mem_eq(sdata.params, sdata.params_len, cparams, sizeof(cparams))
             || !TEST_mem_eq(cdata.params, cdata.params_len, sparams,