]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Extend create_accept_stream test
authorAndrew Dinh <andrewd@openssl.org>
Mon, 30 Jun 2025 15:21:48 +0000 (22:21 +0700)
committerNeil Horman <nhorman@openssl.org>
Thu, 3 Jul 2025 00:55:24 +0000 (20:55 -0400)
- Create more options for creating server-initiated
- Check that correct stream is accepted with SSL_get_stream_type

Reviewed-by: Neil Horman <nhorman@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/27883)

test/quicapitest.c

index 2223f32575a5368970c6d039c8b7a1de77ef2249..83b0e83afb4e4a4fe7c025156aee1869a88f14a3 100644 (file)
@@ -2866,33 +2866,41 @@ static int test_ssl_set_verify(void)
 }
 
 /*
- * Creates a server-initiated stream (unidirectional if is_uni=1, bidirectional
- * otherwise) and tests that client calling SSL_accept_stream with accept_flags
- * behaves as expected.
+ * Creates server-initiated streams (1 uni stream if stream_opt=0, 1 bidi stream
+ * if stream_opt=1, and both a uni and bidi stream if stream_opt=2) and tests
+ * that client calling SSL_accept_stream with accept_flags behaves as expected.
  */
 static int create_accept_stream(SSL *serverssl, SSL *clientssl,
-                                uint64_t accept_flags, int is_uni)
+                                uint64_t accept_flags, int stream_opt)
 {
     unsigned char buf[16], msg[] = "Hello, World!";
-    SSL *clientstream = NULL, *serverstream = NULL, *stream = NULL;
-    int ret = 0, should_accept = 1;
-    uint64_t new_stream_flags = is_uni ? SSL_STREAM_FLAG_UNI : 0;
+    SSL *clientstream = NULL, *serverstream = NULL, *serverstream2 = NULL;
+    int create_uni = stream_opt != 1, create_bidi = stream_opt != 0;
+    int ret = 0, should_accept = 1, stream_type;
     size_t nread, nwritten;
 
-    if (is_uni != 0 && is_uni != 1)
+    if (stream_opt < 0 || stream_opt > 2)
         goto err;
 
     if ((accept_flags & SSL_ACCEPT_STREAM_UNI)
         && !(accept_flags & SSL_ACCEPT_STREAM_BIDI))
-        should_accept = is_uni;
+        should_accept = create_uni;
     else if ((accept_flags & SSL_ACCEPT_STREAM_BIDI)
              && !(accept_flags & SSL_ACCEPT_STREAM_UNI))
-        should_accept = !is_uni;
+        should_accept = create_bidi;
 
-    if (!TEST_ptr(serverstream = SSL_new_stream(serverssl, new_stream_flags))
-        || !TEST_int_gt(SSL_write_ex(serverstream, msg, sizeof(msg), &nwritten), 0)
-        || !TEST_int_eq(nwritten, sizeof(msg))
-        || !TEST_int_eq(SSL_handle_events(clientssl), 1))
+    if (create_uni
+        && (!TEST_ptr(serverstream = SSL_new_stream(serverssl, SSL_STREAM_FLAG_UNI))
+            || !TEST_int_gt(SSL_write_ex(serverstream, msg, sizeof(msg), &nwritten), 0)
+            || !TEST_int_eq(nwritten, sizeof(msg))
+            || !TEST_int_eq(SSL_handle_events(clientssl), 1)))
+        goto err;
+
+    if (create_bidi
+        && (!TEST_ptr(serverstream2 = SSL_new_stream(serverssl, 0))
+            || !TEST_int_gt(SSL_write_ex(serverstream2, msg, sizeof(msg), &nwritten), 0)
+            || !TEST_int_eq(nwritten, sizeof(msg))
+            || !TEST_int_eq(SSL_handle_events(clientssl), 1)))
         goto err;
 
     clientstream = SSL_accept_stream(clientssl, accept_flags);
@@ -2907,18 +2915,27 @@ static int create_accept_stream(SSL *serverssl, SSL *clientssl,
         goto err;
     }
 
+    stream_type = SSL_get_stream_type(clientstream);
+    if (!create_uni && create_bidi && stream_type != SSL_STREAM_TYPE_BIDI)
+        goto err;
+    else if (create_uni && !create_bidi && stream_type != SSL_STREAM_TYPE_READ)
+        goto err;
+
     if (!TEST_int_gt(SSL_read_ex(clientstream, buf, sizeof(buf), &nread), 0)
         || !TEST_int_eq(nread, sizeof(msg)))
         goto err;
 
     ret = 1;
 err:
-    /* In case there is a stream still in the queue */
-    stream = SSL_accept_stream(clientssl, 0);
-
     SSL_free(serverstream);
+    SSL_free(serverstream2);
+    SSL_free(clientstream);
+
+    /* In case there are still streams still in the queue (up to 2) */
+    clientstream = SSL_accept_stream(clientssl, 0);
+    SSL_free(clientstream);
+    clientstream = SSL_accept_stream(clientssl, 0);
     SSL_free(clientstream);
-    SSL_free(stream);
 
     return ret;
 }
@@ -2983,7 +3000,9 @@ static int test_accept_stream(void)
         if (!TEST_int_eq(create_accept_stream(serverssl, clientssl,
                                               accept_flags[j], 0), 1)
             || !TEST_int_eq(create_accept_stream(serverssl, clientssl,
-                                                 accept_flags[i], 1), 1)
+                                                 accept_flags[j], 1), 1)
+            || !TEST_int_eq(create_accept_stream(serverssl, clientssl,
+                                                 accept_flags[j], 2), 1)
             )
             goto err;
     }