]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
QUIC TSERVER: Add support for multiple streams
authorHugo Landau <hlandau@openssl.org>
Tue, 18 Apr 2023 18:30:54 +0000 (19:30 +0100)
committerHugo Landau <hlandau@openssl.org>
Fri, 12 May 2023 13:47:11 +0000 (14:47 +0100)
Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/20765)

include/internal/quic_tserver.h
ssl/quic/quic_tserver.c
test/quic_tserver_test.c
test/quicapitest.c
test/quicfaultstest.c

index 0d0d2014974b2e8f84ff7c4ad57ce16fd3191372..fd657049abb2576849872c4969b360e866d50f98 100644 (file)
@@ -86,6 +86,7 @@ int ossl_quic_tserver_is_terminated(const QUIC_TSERVER *srv);
  * ossl_quic_tserver_has_read_ended() to identify this condition.
  */
 int ossl_quic_tserver_read(QUIC_TSERVER *srv,
+                           uint64_t stream_id,
                            unsigned char *buf,
                            size_t buf_len,
                            size_t *bytes_read);
@@ -93,7 +94,7 @@ int ossl_quic_tserver_read(QUIC_TSERVER *srv,
 /*
  * Returns 1 if the read part of the stream has ended normally.
  */
-int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv);
+int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv, uint64_t stream_id);
 
 /*
  * Attempts to write to stream 0. Writes the number of bytes consumed to
@@ -107,6 +108,7 @@ int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv);
  * Returns 0 if connection is not currently active.
  */
 int ossl_quic_tserver_write(QUIC_TSERVER *srv,
+                            uint64_t stream_id,
                             const unsigned char *buf,
                             size_t buf_len,
                             size_t *bytes_written);
@@ -114,7 +116,15 @@ int ossl_quic_tserver_write(QUIC_TSERVER *srv,
 /*
  * Signals normal end of the stream.
  */
-int ossl_quic_tserver_conclude(QUIC_TSERVER *srv);
+int ossl_quic_tserver_conclude(QUIC_TSERVER *srv, uint64_t stream_id);
+
+/*
+ * Create a server-initiated stream. The stream ID of the newly
+ * created stream is written to *stream_id.
+ */
+int ossl_quic_tserver_stream_new(QUIC_TSERVER *srv,
+                                 int is_uni,
+                                 uint64_t *stream_id);
 
 BIO *ossl_quic_tserver_get0_rbio(QUIC_TSERVER *srv);
 
index 498ea622380ad52b36a67fb829e1f5553b2174c8..12f45166082ea1f0e8054a61ce3412865e3dbed3 100644 (file)
@@ -33,9 +33,6 @@ struct quic_tserver_st {
     /* SSL for the underlying TLS connection */
     SSL *tls;
 
-    /* Our single bidirectional application data stream. */
-    QUIC_STREAM     *stream0;
-
     /* The current peer L4 address. AF_UNSPEC if we do not have a peer yet. */
     BIO_ADDR        cur_peer_addr;
 
@@ -104,10 +101,6 @@ QUIC_TSERVER *ossl_quic_tserver_new(const QUIC_TSERVER_ARGS *args,
         || !ossl_quic_channel_set_net_wbio(srv->ch, srv->args.net_wbio))
         goto err;
 
-    srv->stream0 = ossl_quic_channel_get_stream_by_id(srv->ch, 0);
-    if (srv->stream0 == NULL)
-        goto err;
-
     return srv;
 
 err:
@@ -193,19 +186,40 @@ int ossl_quic_tserver_is_handshake_confirmed(const QUIC_TSERVER *srv)
 }
 
 int ossl_quic_tserver_read(QUIC_TSERVER *srv,
+                           uint64_t stream_id,
                            unsigned char *buf,
                            size_t buf_len,
                            size_t *bytes_read)
 {
     int is_fin = 0;
+    QUIC_STREAM *qs;
 
     if (!ossl_quic_channel_is_active(srv->ch))
         return 0;
 
-    if (srv->stream0->recv_fin_retired)
+    qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch),
+                                        stream_id);
+    if (qs == NULL) {
+        int is_client_init
+            = ((stream_id & QUIC_STREAM_INITIATOR_MASK)
+               == QUIC_STREAM_INITIATOR_CLIENT);
+
+        /*
+         * A client-initiated stream might spontaneously come into existence, so
+         * allow trying to read on a client-initiated stream before it exists.
+         * Otherwise, fail.
+         */
+        if (!is_client_init)
+            return 0;
+
+        *bytes_read = 0;
+        return 1;
+    }
+
+    if (qs->recv_fin_retired || qs->rstream == NULL)
         return 0;
 
-    if (!ossl_quic_rstream_read(srv->stream0->rstream, buf, buf_len,
+    if (!ossl_quic_rstream_read(qs->rstream, buf, buf_len,
                                 bytes_read, &is_fin))
         return 0;
 
@@ -220,35 +234,47 @@ int ossl_quic_tserver_read(QUIC_TSERVER *srv,
 
         ossl_statm_get_rtt_info(ossl_quic_channel_get_statm(srv->ch), &rtt_info);
 
-        if (!ossl_quic_rxfc_on_retire(&srv->stream0->rxfc, *bytes_read,
+        if (!ossl_quic_rxfc_on_retire(&qs->rxfc, *bytes_read,
                                       rtt_info.smoothed_rtt))
             return 0;
     }
 
     if (is_fin)
-        srv->stream0->recv_fin_retired = 1;
+        qs->recv_fin_retired = 1;
 
     if (*bytes_read > 0)
-        ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch),
-                                          srv->stream0);
+        ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), qs);
 
     return 1;
 }
 
-int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv)
+int ossl_quic_tserver_has_read_ended(QUIC_TSERVER *srv, uint64_t stream_id)
 {
-    return srv->stream0->recv_fin_retired;
+    QUIC_STREAM *qs;
+
+    qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch),
+                                        stream_id);
+
+    return qs != NULL && qs->recv_fin_retired;
 }
 
 int ossl_quic_tserver_write(QUIC_TSERVER *srv,
+                            uint64_t stream_id,
                             const unsigned char *buf,
                             size_t buf_len,
                             size_t *bytes_written)
 {
+    QUIC_STREAM *qs;
+
     if (!ossl_quic_channel_is_active(srv->ch))
         return 0;
 
-    if (!ossl_quic_sstream_append(srv->stream0->sstream,
+    qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch),
+                                        stream_id);
+    if (qs == NULL || qs->sstream == NULL)
+        return 0;
+
+    if (!ossl_quic_sstream_append(qs->sstream,
                                   buf, buf_len, bytes_written))
         return 0;
 
@@ -257,29 +283,50 @@ int ossl_quic_tserver_write(QUIC_TSERVER *srv,
          * We have appended at least one byte to the stream. Potentially mark
          * the stream as active, depending on FC.
          */
-        ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch),
-                                          srv->stream0);
+        ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), qs);
 
     /* Try and send. */
     ossl_quic_tserver_tick(srv);
     return 1;
 }
 
-int ossl_quic_tserver_conclude(QUIC_TSERVER *srv)
+int ossl_quic_tserver_conclude(QUIC_TSERVER *srv, uint64_t stream_id)
 {
+    QUIC_STREAM *qs;
+
     if (!ossl_quic_channel_is_active(srv->ch))
         return 0;
 
-    if (!ossl_quic_sstream_get_final_size(srv->stream0->sstream, NULL)) {
-        ossl_quic_sstream_fin(srv->stream0->sstream);
-        ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch),
-                                          srv->stream0);
+    qs = ossl_quic_stream_map_get_by_id(ossl_quic_channel_get_qsm(srv->ch),
+                                        stream_id);
+    if  (qs == NULL || qs->sstream == NULL)
+        return 0;
+
+    if (!ossl_quic_sstream_get_final_size(qs->sstream, NULL)) {
+        ossl_quic_sstream_fin(qs->sstream);
+        ossl_quic_stream_map_update_state(ossl_quic_channel_get_qsm(srv->ch), qs);
     }
 
     ossl_quic_tserver_tick(srv);
     return 1;
 }
 
+int ossl_quic_tserver_stream_new(QUIC_TSERVER *srv,
+                                 int is_uni,
+                                 uint64_t *stream_id)
+{
+    QUIC_STREAM *qs;
+
+    if (!ossl_quic_channel_is_active(srv->ch))
+        return 0;
+
+    if ((qs = ossl_quic_channel_new_stream_local(srv->ch, is_uni)) == NULL)
+        return 0;
+
+    *stream_id = qs->id;
+    return 1;
+}
+
 BIO *ossl_quic_tserver_get0_rbio(QUIC_TSERVER *srv)
 {
     return srv->args.net_rbio;
index a385381716b8d4b250198f772daa164177ab5415..e9ae4703b2069c96cc4969c6e2a3b7467bbcd6b4 100644 (file)
@@ -215,16 +215,17 @@ static int do_test(int use_thread_assist, int use_fake_time, int use_inject)
         }
 
         if (c_connected && c_write_done && !s_read_done) {
-            if (!ossl_quic_tserver_read(tserver,
+            if (!ossl_quic_tserver_read(tserver, 0,
                                         (unsigned char *)msg2 + s_total_read,
                                         sizeof(msg2) - s_total_read, &l)) {
-                if (!TEST_true(ossl_quic_tserver_has_read_ended(tserver)))
+                if (!TEST_true(ossl_quic_tserver_has_read_ended(tserver, 0)))
                     goto err;
 
                 if (!TEST_mem_eq(msg1, sizeof(msg1) - 1, msg2, s_total_read))
                     goto err;
 
                 s_begin_write = 1;
+                s_read_done   = 1;
             } else {
                 s_total_read += l;
                 if (!TEST_size_t_le(s_total_read, sizeof(msg1) - 1))
@@ -233,7 +234,7 @@ static int do_test(int use_thread_assist, int use_fake_time, int use_inject)
         }
 
         if (s_begin_write && s_total_written < sizeof(msg1) - 1) {
-            if (!TEST_true(ossl_quic_tserver_write(tserver,
+            if (!TEST_true(ossl_quic_tserver_write(tserver, 0,
                                                    (unsigned char *)msg2 + s_total_written,
                                                    sizeof(msg1) - 1 - s_total_written, &l)))
                 goto err;
@@ -241,7 +242,7 @@ static int do_test(int use_thread_assist, int use_fake_time, int use_inject)
             s_total_written += l;
 
             if (s_total_written == sizeof(msg1) - 1) {
-                ossl_quic_tserver_conclude(tserver);
+                ossl_quic_tserver_conclude(tserver, 0);
                 c_begin_read = 1;
             }
         }
index 092e303ba65b0fb14117199de07d55ca82b29c93..3ce695e5e65d6decd37e17e4c238b8e8cc8070d4 100644 (file)
@@ -42,6 +42,7 @@ static int test_quic_write_read(int idx)
     size_t msglen = strlen(msg);
     size_t numbytes = 0;
     int ssock = 0, csock = 0;
+    uint64_t sid = UINT64_MAX;
 
     if (idx == 1 && !qtest_supports_blocking())
         return TEST_skip("Blocking tests not supported in this build");
@@ -61,6 +62,10 @@ static int test_quic_write_read(int idx)
             goto end;
     }
 
+    if (!TEST_true(ossl_quic_tserver_stream_new(qtserv, /*is_uni=*/0, &sid))
+        || !TEST_uint64_t_eq(sid, 1)) /* server-initiated, so first SID is 1 */
+        goto end;
+
     for (j = 0; j < 2; j++) {
         /* Check that sending and receiving app data is ok */
         if (!TEST_true(SSL_write_ex(clientquic, msg, msglen, &numbytes)))
@@ -72,7 +77,7 @@ static int test_quic_write_read(int idx)
 
                 ossl_quic_tserver_tick(qtserv);
 
-                if (!TEST_true(ossl_quic_tserver_read(qtserv, buf, sizeof(buf),
+                if (!TEST_true(ossl_quic_tserver_read(qtserv, sid, buf, sizeof(buf),
                                                       &numbytes)))
                     goto end;
             } while (numbytes == 0);
@@ -81,7 +86,7 @@ static int test_quic_write_read(int idx)
                 goto end;
         }
 
-        if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg,
+        if (!TEST_true(ossl_quic_tserver_write(qtserv, sid, (unsigned char *)msg,
                                                msglen, &numbytes)))
             goto end;
         ossl_quic_tserver_tick(qtserv);
index beb3e4dc41319dfb53bac483f0c0f310baf59515..fbbbad4dd68f69bbd9f6d18525fe5ff748a4667e 100644 (file)
@@ -45,7 +45,7 @@ static int test_basic(void)
         goto err;
 
     ossl_quic_tserver_tick(qtserv);
-    if (!TEST_true(ossl_quic_tserver_read(qtserv, buf, sizeof(buf), &bytesread)))
+    if (!TEST_true(ossl_quic_tserver_read(qtserv, 0, buf, sizeof(buf), &bytesread)))
         goto err;
 
     /*
@@ -119,7 +119,7 @@ static int test_unknown_frame(void)
                                                          NULL)))
         goto err;
 
-    if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg, msglen,
+    if (!TEST_true(ossl_quic_tserver_write(qtserv, 0, (unsigned char *)msg, msglen,
                                            &byteswritten)))
         goto err;
 
@@ -294,7 +294,7 @@ static int test_corrupted_data(int idx)
      * Send first 5 bytes of message. This will get corrupted and is treated as
      * "lost"
      */
-    if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg, 5,
+    if (!TEST_true(ossl_quic_tserver_write(qtserv, 0, (unsigned char *)msg, 5,
                                            &byteswritten)))
         goto err;
 
@@ -317,7 +317,7 @@ static int test_corrupted_data(int idx)
     OSSL_sleep(100);
 
     /* Send rest of message */
-    if (!TEST_true(ossl_quic_tserver_write(qtserv, (unsigned char *)msg + 5,
+    if (!TEST_true(ossl_quic_tserver_write(qtserv, 0, (unsigned char *)msg + 5,
                                            msglen - 5, &byteswritten)))
         goto err;