]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Add a test for the quic-tls API
authorMatt Caswell <matt@openssl.org>
Wed, 21 Aug 2024 13:50:55 +0000 (14:50 +0100)
committerMatt Caswell <matt@openssl.org>
Tue, 11 Feb 2025 17:17:10 +0000 (17:17 +0000)
Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Neil Horman <nhorman@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/26683)

test/sslapitest.c

index a6b979bceb5ff6537333b27007401e3782e4f576..9428c1f59ddfe5525dc5caeaf94fb3adf16088aa 100644 (file)
@@ -12403,6 +12403,227 @@ static int test_alpn(int idx)
     return testresult;
 }
 
+#if !defined(OPENSSL_NO_QUIC) && !defined(OSSL_NO_USABLE_TLS1_3)
+struct quic_tls_test_data {
+    struct quic_tls_test_data *peer;
+    uint32_t renc_level;
+    uint32_t wenc_level;
+    unsigned char rcd_data[4][2048];
+    size_t rcd_data_len[4];
+    unsigned char rsecret[3][48];
+    size_t rsecret_len[3];
+    unsigned char wsecret[3][48];
+    size_t wsecret_len[3];
+    unsigned char params[3];
+    size_t params_len;
+    int alert;
+    int err;
+};
+
+static int crypto_send_cb(SSL *s, const unsigned char *buf, size_t buf_len,
+                          size_t *consumed, void *arg)
+{
+    struct quic_tls_test_data *data = (struct quic_tls_test_data *)arg;
+    struct quic_tls_test_data *peer = data->peer;
+    size_t max_len = sizeof(peer->rcd_data[data->wenc_level])
+                     - peer->rcd_data_len[data->wenc_level];
+
+    if (buf_len > max_len)
+        buf_len = max_len;
+
+    if (buf_len == 0) {
+        *consumed = 0;
+        return 1;
+    }
+
+    memcpy(peer->rcd_data[data->wenc_level]
+           + peer->rcd_data_len[data->wenc_level], buf, buf_len);
+    peer->rcd_data_len[data->wenc_level] += buf_len;
+
+    *consumed = buf_len;
+    return 1;
+}
+static int crypto_recv_rcd_cb(SSL *s, const unsigned char **buf,
+                              size_t *bytes_read, void *arg)
+{
+    struct quic_tls_test_data *data = (struct quic_tls_test_data *)arg;
+
+    *bytes_read = data->rcd_data_len[data->renc_level];
+    *buf = data->rcd_data[data->renc_level];
+    return 1;
+}
+
+static int crypto_release_rcd_cb(SSL *s, size_t bytes_read, void *arg)
+{
+    struct quic_tls_test_data *data = (struct quic_tls_test_data *)arg;
+
+    if (!TEST_size_t_eq(bytes_read, data->rcd_data_len[data->renc_level])
+            || !TEST_size_t_gt(bytes_read, 0)) {
+        data->err = 1;
+        return 0;
+    }
+    data->rcd_data_len[data->renc_level] = 0;
+
+    return 1;
+}
+
+static int yield_secret_cb(SSL *s, uint32_t prot_level, int direction,
+                           const unsigned char *secret, size_t secret_len,
+                           void *arg)
+{
+    struct quic_tls_test_data *data = (struct quic_tls_test_data *)arg;
+
+    if (prot_level < OSSL_RECORD_PROTECTION_LEVEL_EARLY
+            || prot_level > OSSL_RECORD_PROTECTION_LEVEL_APPLICATION)
+        goto err;
+
+    switch (direction) {
+    case 0: /* read */
+        if (!TEST_size_t_le(secret_len, sizeof(data->rsecret)))
+            goto err;
+        data->renc_level = prot_level;
+        memcpy(data->rsecret[prot_level - 1], secret, secret_len);
+        data->rsecret_len[prot_level - 1] = secret_len;
+        break;
+
+    case 1: /* write */
+        if (!TEST_size_t_le(secret_len, sizeof(data->wsecret)))
+            goto err;
+        data->wenc_level = prot_level;
+        memcpy(data->wsecret[prot_level - 1], secret, secret_len);
+        data->wsecret_len[prot_level - 1] = secret_len;
+        break;
+
+    default:
+        goto err;
+    }
+
+    return 1;
+ err:
+    data->err = 1;
+    return 0;
+}
+
+static int got_transport_params_cb(SSL *s, const unsigned char *params,
+                                   size_t params_len,
+                                   void *arg)
+{
+    struct quic_tls_test_data *data = (struct quic_tls_test_data *)arg;
+
+    if (!TEST_size_t_le(params_len, sizeof(data->params))) {
+        data->err = 1;
+        return 0;
+    }
+
+    memcpy(data->params, params, params_len);
+    data->params_len = params_len;
+
+    return 1;
+}
+
+static int alert_cb(SSL *s, unsigned char alert_code, void *arg)
+{
+    struct quic_tls_test_data *data = (struct quic_tls_test_data *)arg;
+
+    data->alert = 1;
+    return 1;
+}
+
+/*
+ * Test the QUIC TLS API
+ */
+static int test_quic_tls(void)
+{
+    SSL_CTX *sctx = NULL, *cctx = NULL;
+    SSL *serverssl = NULL, *clientssl = NULL;
+    int testresult = 0;
+    const OSSL_DISPATCH qtdis[] = {
+        {OSSL_FUNC_SSL_QUIC_TLS_CRYPTO_SEND, (void (*)(void))crypto_send_cb},
+        {OSSL_FUNC_SSL_QUIC_TLS_CRYPTO_RECV_RCD,
+         (void (*)(void))crypto_recv_rcd_cb},
+        {OSSL_FUNC_SSL_QUIC_TLS_CRYPTO_RELEASE_RCD,
+         (void (*)(void))crypto_release_rcd_cb},
+        {OSSL_FUNC_SSL_QUIC_TLS_YIELD_SECRET,
+         (void (*)(void))yield_secret_cb},
+        {OSSL_FUNC_SSL_QUIC_TLS_GOT_TRANSPORT_PARAMS,
+         (void (*)(void))got_transport_params_cb},
+        {OSSL_FUNC_SSL_QUIC_TLS_ALERT, (void (*)(void))alert_cb},
+        {0, NULL}
+    };
+    struct quic_tls_test_data sdata, cdata;
+    const unsigned char cparams[] = {
+        0xff, 0x01, 0x00
+    };
+    const unsigned char sparams[] = {
+        0xfe, 0x01, 0x00
+    };
+    int i;
+
+    memset(&sdata, 0, sizeof(sdata));
+    memset(&cdata, 0, sizeof(cdata));
+    sdata.peer = &cdata;
+    cdata.peer = &sdata;
+
+    if (!TEST_true(create_ssl_ctx_pair(libctx, TLS_server_method(),
+                                       TLS_client_method(), TLS1_3_VERSION, 0,
+                                       &sctx, &cctx, cert, privkey)))
+        goto end;
+
+    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL,
+                                      NULL)))
+        goto end;
+
+    if (!TEST_true(SSL_set_quic_tls_cbs(clientssl, qtdis, &cdata))
+            || !TEST_true(SSL_set_quic_tls_cbs(serverssl, qtdis, &sdata))
+            || !TEST_true(SSL_set_quic_tls_transport_params(clientssl, cparams,
+                                                            sizeof(cparams)))
+            || !TEST_true(SSL_set_quic_tls_transport_params(serverssl, sparams,
+                                                            sizeof(sparams))))
+        goto end;
+
+    if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)))
+        goto end;
+
+    /* Check no problems during the handshake */
+    if (!TEST_false(sdata.alert)
+            || !TEST_false(cdata.alert)
+            || !TEST_false(sdata.err)
+            || !TEST_false(cdata.err))
+        goto end;
+
+    /* Check the secrets all match */
+    for (i = OSSL_RECORD_PROTECTION_LEVEL_EARLY - 1;
+         i < OSSL_RECORD_PROTECTION_LEVEL_APPLICATION;
+         i++) {
+        if (!TEST_mem_eq(sdata.wsecret[i], sdata.wsecret_len[i],
+                         cdata.rsecret[i], cdata.rsecret_len[i]))
+            goto end;
+    }
+
+    /* Check the transport params */
+    if (!TEST_mem_eq(sdata.params, sdata.params_len, cparams, sizeof(cparams))
+            || !TEST_mem_eq(cdata.params, cdata.params_len, sparams,
+                            sizeof(sparams)))
+        goto end;
+
+    /* Check the encryption levels are what we expect them to be */
+    if (!TEST_true(sdata.renc_level == OSSL_RECORD_PROTECTION_LEVEL_APPLICATION)
+            || !TEST_true(sdata.wenc_level == OSSL_RECORD_PROTECTION_LEVEL_APPLICATION)
+            || !TEST_true(cdata.renc_level == OSSL_RECORD_PROTECTION_LEVEL_APPLICATION)
+            || !TEST_true(sdata.wenc_level == OSSL_RECORD_PROTECTION_LEVEL_APPLICATION))
+        goto end;
+
+    testresult = 1;
+ end:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    SSL_CTX_free(sctx);
+    SSL_CTX_free(cctx);
+
+    return testresult;
+}
+#endif /* !defined(OPENSSL_NO_QUIC) && !defined(OSSL_NO_USABLE_TLS1_3) */
+
 OPT_TEST_DECLARE_USAGE("certfile privkeyfile srpvfile tmpfile provider config dhfile\n")
 
 int setup_tests(void)
@@ -12725,6 +12946,9 @@ int setup_tests(void)
     ADD_ALL_TESTS(test_npn, 5);
 #endif
     ADD_ALL_TESTS(test_alpn, 4);
+#if !defined(OPENSSL_NO_QUIC) && !defined(OSSL_NO_USABLE_TLS1_3)
+    ADD_TEST(test_quic_tls);
+#endif
     return 1;
 
  err: