]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Arrange the remove_id() logic to be able to remove multiple stream.
authorJean-Frederic Clere <jfclere@gmail.com>
Tue, 3 Dec 2024 15:18:03 +0000 (16:18 +0100)
committerNeil Horman <nhorman@openssl.org>
Mon, 17 Feb 2025 16:27:33 +0000 (11:27 -0500)
create a new h3conn in read_from_ssl_ids() when we have a new
connection.

Reviewed-by: Neil Horman <nhorman@openssl.org>
Reviewed-by: Saša Nedvědický <sashan@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/25859)

demos/http3/ossl-nghttp3-demo-server.c

index 27113f811612ceb26a6163241bd0d06946ab5902..1aa4a38dd3b978d08227243225d1a0e0dc5fbe42 100644 (file)
 /* The crappy test wants 20 bytes */
 static uint8_t nulldata[20] = "12345678901234567890";
 
+/* The nghttp3 variable we need in the main part and read_from_ssl_ids */
+static nghttp3_settings settings;
+static const nghttp3_mem *mem;
+static nghttp3_callbacks callbacks = {0};
+
 /* 3 streams created by the server and 4 by the client (one is bidi) */
 struct ssl_id {
     SSL *s;      /* the stream openssl uses in SSL_read(),  SSL_write etc */
     uint64_t id; /* the stream identifier the nghttp3 uses */
-    int status;  /* 0, CLIENTUNIOPEN or CLIENTUNIOPEN|CLIENTCLOSED (for the moment) */
+    int status;  /* 0 or one the below status and origin */
 };
 /* status and origin of the streams the possible values are: */
 #define CLIENTUNIOPEN  0x01 /* unidirectional open by the client (2, 6 and 10) */
@@ -33,6 +38,7 @@ struct ssl_id {
 #define CLIENTBIDIOPEN 0x04 /* bidirectional open by the client (something like 0, 4, 8 ...) */
 #define SERVERUNIOPEN  0x08 /* unidirectional open by the server (3, 7 and 11) */
 #define SERVERCLOSED   0x10 /* closed by the server (us) */
+#define TOBEREMOVED    0x20 /* marked for removing in read_from_ssl_ids, removed after processing all events */
 
 #define MAXSSL_IDS 20
 #define MAXURL 255
@@ -136,18 +142,16 @@ static void add_id_at(uint64_t id, SSL *ssl, int at, struct h3ssl *h3ssl)
     exit(1);
 }
 
-static void remove_id(uint64_t id, struct h3ssl *h3ssl)
+/* remove the ids marked for removal */
+static void remove_marked_ids(struct h3ssl *h3ssl)
 {
     struct ssl_id *ssl_ids;
     int i;
 
     ssl_ids = h3ssl->ssl_ids;
-    if (id == UINT64_MAX)
-        return;
     for (i = 0; i < MAXSSL_IDS; i++) {
-        if (ssl_ids[i].id == id) {
+        if (ssl_ids[i].status & TOBEREMOVED) {
             printf("remove_id %llu\n", (unsigned long long) ssl_ids[i].id);
-            /* XXX: don't work SSL_clear(ssl_ids[i].s); */
             SSL_free(ssl_ids[i].s);
             ssl_ids[i].s = NULL;
             ssl_ids[i].id = UINT64_MAX;
@@ -157,6 +161,7 @@ static void remove_id(uint64_t id, struct h3ssl *h3ssl)
     }
 }
 
+/* add the status bytes to the status */
 static void set_id_status(uint64_t id, int status, struct h3ssl *h3ssl)
 {
     struct ssl_id *ssl_ids;
@@ -166,7 +171,7 @@ static void set_id_status(uint64_t id, int status, struct h3ssl *h3ssl)
     for (i = 0; i < MAXSSL_IDS; i++) {
         if (ssl_ids[i].id == id) {
             printf("set_id_status: %llu to %d\n", (unsigned long long) ssl_ids[i].id, status);
-            ssl_ids[i].status = status;
+            ssl_ids[i].status = ssl_ids[i].status | status;
             return;
         }
     }
@@ -200,13 +205,13 @@ static int are_all_clientid_closed(struct h3ssl *h3ssl)
     for (i = 0; i < MAXSSL_IDS; i++) {
         if (ssl_ids[i].id == UINT64_MAX)
             continue;
-        if (ssl_ids[i].status == CLIENTUNIOPEN) {
+        if (ssl_ids[i].status & CLIENTUNIOPEN) {
             printf("are_all_clientid_closed: %llu open\n", (unsigned long long) ssl_ids[i].id);
             return 0;
         }
         printf("are_all_clientid_closed: %llu status %d : %d\n",
                (unsigned long long) ssl_ids[i].id, ssl_ids[i].status, CLIENTUNIOPEN | CLIENTCLOSED);
-        if (ssl_ids[i].status == (CLIENTUNIOPEN | CLIENTCLOSED)) {
+        if (ssl_ids[i].status & (CLIENTUNIOPEN | CLIENTCLOSED)) {
             printf("are_all_clientid_closed: %llu closed\n", (unsigned long long) ssl_ids[i].id);
             SSL_free(ssl_ids[i].s);
             ssl_ids[i].s = NULL;
@@ -401,7 +406,7 @@ static int quic_server_h3streams(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
 }
 
 /* Try to read from the streams we have */
-static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
+static int read_from_ssl_ids(nghttp3_conn **curh3conn, struct h3ssl *h3ssl)
 {
     int hassomething = 0, i;
     struct ssl_id *ssl_ids = h3ssl->ssl_ids;
@@ -410,7 +415,8 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
     size_t result_count = SIZE_MAX;
     int numitem = 0, ret;
     uint64_t processed_event = 0;
-    uint64_t id_to_remove = UINT64_MAX;
+    int has_ids_to_remove = 0;
+    nghttp3_conn *h3conn = *curh3conn;
 
     /*
      * Process all the streams
@@ -473,12 +479,6 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
                 ret = -1;
                 goto err;
             }
-            if (!SSL_set_incoming_stream_policy(conn,
-                                                SSL_INCOMING_STREAM_POLICY_ACCEPT, 0)) {
-                fprintf(stderr, "error while setting inccoming stream policy\n");
-                ret = -1;
-                goto err;
-            }
 
             /* the previous might be still there */
             if (ssl_ids[1].s) {
@@ -487,8 +487,28 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
                 SSL_free(h3ssl->ssl_ids[1].s);
                 h3ssl->ssl_ids[1].s = NULL;
                 reuse_h3ssl(h3ssl);
+                close_all_ids(h3ssl);
+                h3ssl->id_bidi = UINT64_MAX;
+                h3ssl->has_uni = 0;
+                /* we need a new h3conn here!!! */
+                if (nghttp3_conn_server_new(curh3conn, &callbacks, &settings, mem,
+                                            h3ssl)) {
+                    fprintf(stderr, "nghttp3_conn_client_new failed!\n");
+                    exit(1);
+                }
+                /* XXX : nghttp3_conn_del() for the old one? */
+                h3conn = *curh3conn;
+
+                hassomething++;
                 h3ssl->new_conn = 1;
             }
+            if (!SSL_set_incoming_stream_policy(conn,
+                                                SSL_INCOMING_STREAM_POLICY_ACCEPT, 0)) {
+                fprintf(stderr, "error while setting inccoming stream policy\n");
+                ret = -1;
+                goto err;
+            }
+
             add_id_at(-1, conn, 1, h3ssl);
             printf("SSL_accept_connection\n");
             processed_event = processed_event + SSL_POLL_EVENT_IC;
@@ -515,10 +535,10 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
             }
             if (SSL_get_stream_type(stream) == SSL_STREAM_TYPE_BIDI) {
                 /* bidi that is the id  where we have to send the response */
-                printf("=> Received connection on %lld ISBIDI\n",
-                       (unsigned long long) new_id);
-                if (h3ssl->id_bidi != UINT64_MAX)
-                    id_to_remove = h3ssl->id_bidi;
+                if (h3ssl->id_bidi != UINT64_MAX) {
+                    set_id_status(h3ssl->id_bidi, TOBEREMOVED, h3ssl);
+                    has_ids_to_remove++;
+                }
                 h3ssl->id_bidi = new_id;
                 reuse_h3ssl(h3ssl);
                 h3ssl->restart = 1;
@@ -609,9 +629,8 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
             status = get_id_status(id, h3ssl);
 
             printf("revent exception READ on %llu\n", (unsigned long long)id);
-            if (status == CLIENTUNIOPEN) {
-                status = status | CLIENTCLOSED;
-                set_id_status(id, status, h3ssl);
+            if (status & CLIENTUNIOPEN) {
+                set_id_status(id, CLIENTCLOSED, h3ssl);
                 hassomething++;
             }
             processed_event = processed_event + SSL_POLL_EVENT_ER;
@@ -628,9 +647,10 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
             id = SSL_get_stream_id(item->desc.value.ssl);
             status = get_id_status(id, h3ssl);
 
-            if (status == SERVERCLOSED) {
+            if (status & SERVERCLOSED) {
                 printf("both sides closed on  %llu\n", (unsigned long long)id);
-                id_to_remove = id;
+                set_id_status(id, TOBEREMOVED, h3ssl);
+                has_ids_to_remove++;
                 hassomething++;
             }
             processed_event = processed_event + SSL_POLL_EVENT_EW;
@@ -647,8 +667,8 @@ static int read_from_ssl_ids(nghttp3_conn *h3conn, struct h3ssl *h3ssl)
     }
     ret = hassomething;
 err:
-    if (id_to_remove != UINT64_MAX)
-        remove_id(id_to_remove, h3ssl);
+    if (has_ids_to_remove)
+        remove_marked_ids(h3ssl);
     return ret;
 }
 
@@ -958,11 +978,8 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         goto err;
 
     for (;;) {
-        nghttp3_conn *h3conn;
-        nghttp3_settings settings;
-        nghttp3_callbacks callbacks = {0};
+        nghttp3_conn *h3conn = NULL;
         struct h3ssl h3ssl;
-        const nghttp3_mem *mem = nghttp3_mem_default();
         nghttp3_nv resp[10];
         size_t num_nv;
         nghttp3_data_reader dr;
@@ -1000,14 +1017,19 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         callbacks.recv_data = on_recv_data;
         callbacks.end_stream = on_end_stream;
 
+        /* mem default */
+        mem = nghttp3_mem_default();
+
         handle_events_from_ids(&h3ssl);
 
     newconn:
-        nghttp3_settings_default(&settings);
-        if (nghttp3_conn_server_new(&h3conn, &callbacks, &settings, mem,
-                                    &h3ssl)) {
-            fprintf(stderr, "nghttp3_conn_client_new failed!\n");
-            exit(1);
+        if (h3conn == NULL) {
+            nghttp3_settings_default(&settings);
+            if (nghttp3_conn_server_new(&h3conn, &callbacks, &settings, mem,
+                                        &h3ssl)) {
+                fprintf(stderr, "nghttp3_conn_client_new failed!\n");
+                exit(1);
+            }
         }
 
         printf("process_server starting...\n");
@@ -1029,7 +1051,7 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
                 }
                 handle_events_from_ids(&h3ssl);
             }
-            hassomething = read_from_ssl_ids(h3conn, &h3ssl);
+            hassomething = read_from_ssl_ids(&h3conn, &h3ssl);
             if (hassomething == -1) {
                 fprintf(stderr, "read_from_ssl_ids hassomething failed\n");
                 goto err;
@@ -1152,10 +1174,7 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
              * close stream zero
              */
             if (!h3ssl.close_done) {
-                int status = get_id_status(h3ssl.id_bidi, &h3ssl);
-
-                status = status | SERVERCLOSED;
-                set_id_status(h3ssl.id_bidi, status, &h3ssl);
+                set_id_status(h3ssl.id_bidi, SERVERCLOSED, &h3ssl);
                 h3ssl.close_wait = 1;
             }
         } else {
@@ -1177,7 +1196,7 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
                 /* we have something or a timeout */
                 handle_events_from_ids(&h3ssl);
             }
-            hasnothing = read_from_ssl_ids(h3conn, &h3ssl);
+            hasnothing = read_from_ssl_ids(&h3conn, &h3ssl);
             if (hasnothing == -1) {
                 printf("hasnothing failed\n");
                 break;
@@ -1195,7 +1214,6 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
                 }
                 if (h3ssl.new_conn) {
                     printf("hasnothing something... NEW CONN\n");
-                    close_all_ids(&h3ssl);
                     h3ssl.new_conn = 0;
                     goto newconn;
                 }
@@ -1219,6 +1237,7 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         close_all_ids(&h3ssl);
         SSL_free(h3ssl.ssl_ids[1].s);
         h3ssl.ssl_ids[1].s = NULL;
+        h3conn = NULL; /* XXX need nghttp3_conn_del() ? */
     }
 
     ok = 1;